Almost have sync working

This commit is contained in:
Mariano Sorgente 2020-03-06 17:23:03 +09:00
parent 3f0c59336f
commit 284faa46e8
No known key found for this signature in database
GPG Key ID: 0F866338C369278C
8 changed files with 641 additions and 154 deletions

View File

@@ -4,7 +4,7 @@ ping_interval: 120
# Controls logging of all servers (harvester, farmer, etc.). Each one can be overridden.
logging: &logging
log_stdout: False # If True, outputs to stdout instead of a file
log_stdout: True # If True, outputs to stdout instead of a file
log_filename: "chia.log"
harvester:
@@ -124,9 +124,14 @@ introducer:
wallet:
host: 127.0.0.1
port: 8223
port: 8449
rpc_port: 9256
# The minimum height that we care about for our transactions. Set to zero
# if we are restoring from a private key and don't know the height.
starting_height: 0
num_sync_batches: 10
full_node_peer:
host: 127.0.0.1
port: 8444
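
For reference, a minimal sketch (not part of this commit) of how the two new wallet keys are consumed; the load_config import path is assumed, and the comments paraphrase how _sync below uses the values.

from src.util.config import load_config  # import path assumed

wallet_config = load_config("config.yaml", "wallet")
# starting_height == 0: validate headers starting at the fork point; a non-zero value
# lets a wallet restored from a private key skip full validation of very old headers.
print(wallet_config["starting_height"])
# num_sync_batches: how many header batches _sync requests ahead of the current checkpoint.
print(wallet_config["num_sync_batches"])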

View File

@@ -1821,7 +1821,7 @@ class FullNode:
)
yield OutboundMessage(
NodeType.WALLET,
Message("respond_block", response),
Message("respond_header", response),
Delivery.RESPOND,
)
return

View File

@@ -41,7 +41,7 @@ async def main():
asyncio.get_running_loop().add_signal_handler(signal.SIGINT, server.close_all)
asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, server.close_all)
_ = await server.start_server(config["host"], wallet._on_connect)
_ = await server.start_server(config["host"], None)
await asyncio.sleep(1)
_ = await server.start_client(full_node_peer, None)

View File

@@ -1,17 +1,28 @@
from pathlib import Path
import asyncio
from typing import Dict, Optional
from blspy import ExtendedPrivateKey
import time
from typing import Dict, Optional, Tuple, List
import concurrent
import logging
from blspy import ExtendedPrivateKey
from src.util.merkle_set import (
confirm_included_already_hashed,
confirm_not_included_already_hashed,
)
from src.protocols import wallet_protocol
from src.consensus.constants import constants as consensus_constants
from src.server.server import ChiaServer
from src.server.outbound_message import OutboundMessage, NodeType, Message, Delivery
from src.util.ints import uint32
from src.util.ints import uint32, uint64
from src.types.sized_bytes import bytes32
from src.util.api_decorators import api_request
from src.wallet.wallet import Wallet
from src.wallet.wallet_state_manager import WalletStateManager
from src.wallet.block_record import BlockRecord
from src.types.header_block import HeaderBlock
from src.types.full_block import FullBlock
from src.types.hashable.coin import Coin, hash_coin_list
from src.full_node.blockchain import ReceiveBlockResult
class WalletNode:
@@ -22,12 +33,25 @@ class WalletNode:
wallet_state_manager: WalletStateManager
log: logging.Logger
wallet: Wallet
cached_blocks: Dict[bytes32, Tuple[BlockRecord, HeaderBlock]]
cached_removals: Dict[bytes32, List[bytes32]]
cached_additions: Dict[bytes32, List[Coin]]
proof_hashes: List[Tuple[bytes32, Optional[uint64]]]
header_hashes: List[bytes32]
potential_blocks_received: Dict[uint32, asyncio.Event]
potential_header_hashes: Dict[uint32, bytes32]
constants: Dict
short_sync_threshold: int
sync_mode: bool
_shut_down: bool
@staticmethod
async def create(
config: Dict, key_config: Dict, name: str = None, override_constants: Dict = {}
config: Dict,
key_config: Dict,
name: str = None,
db_path=None,
override_constants: Dict = {},
):
self = WalletNode()
self.config = config
@@ -43,15 +67,31 @@ class WalletNode:
self.log = logging.getLogger(__name__)
pub_hex = self.private_key.get_public_key().serialize().hex()
path = Path(f"wallet_db_{pub_hex}.db")
if not db_path:
path = Path(f"wallet_db_{pub_hex}.db")
else:
path = db_path
self.wallet_state_manager = await WalletStateManager.create(
config, path, override_constants=override_constants,
config, path, override_constants=self.constants,
)
self.wallet = await Wallet.create(config, key_config, self.wallet_state_manager)
self.server = None
# Normal operation data
self.cached_blocks = {}
self.cached_removals = {}
self.cached_additions = {}
# Sync data
self.sync_mode = False
self._shut_down = False
self.proof_hashes = []
self.header_hashes = []
self.short_sync_threshold = 10
self.potential_blocks_received = {}
self.potential_header_hashes = {}
self.server = None
return self
@@ -59,13 +99,262 @@ class WalletNode:
self.server = server
self.wallet.set_server(server)
def _shutdown(self):
self._shut_down = True
async def _sync(self):
"""
Wallet has fallen far behind (or is starting up for the first time), and must be synced
up to the tip of the blockchain
"""
# TODO(mariano): implement
pass
# 1. Get all header hashes
self.header_hashes = []
self.proof_hashes = []
self.potential_header_hashes = {}
genesis = FullBlock.from_bytes(self.constants["GENESIS_BLOCK"])
genesis_challenge = genesis.proof_of_space.challenge_hash
request_header_hashes = wallet_protocol.RequestAllHeaderHashesAfter(
uint32(0), genesis_challenge
)
yield OutboundMessage(
NodeType.FULL_NODE,
Message("request_all_header_hashes_after", request_header_hashes),
Delivery.RESPOND,
)
timeout = 100
start_wait = time.time()
while time.time() - start_wait < timeout:
if self._shut_down:
return
if len(self.header_hashes) > 0:
break
await asyncio.sleep(0.5)
if len(self.header_hashes) == 0:
raise TimeoutError("Took too long to fetch header hashes.")
# 2. Find fork point
fork_point_height: uint32 = self.wallet_state_manager.find_fork_point(
self.header_hashes
)
fork_point_hash: bytes32 = self.header_hashes[fork_point_height]
self.log.info(f"Fork point: {fork_point_hash} at height {fork_point_height}")
tip_height = (
len(self.header_hashes) - 5
if len(self.header_hashes) > 5
else len(self.header_hashes)
)
header_validate_start_height: uint32
if self.config["starting_height"] == 0:
header_validate_start_height = fork_point_height
else:
# Request all proof hashes
request_proof_hashes = wallet_protocol.RequestAllProofHashes()
yield OutboundMessage(
NodeType.FULL_NODE,
Message("request_all_proof_hashes", request_proof_hashes),
Delivery.RESPOND,
)
start_wait = time.time()
while time.time() - start_wait < timeout:
if self._shut_down:
return
if len(self.proof_hashes) > 0:
break
await asyncio.sleep(0.5)
if len(self.proof_hashes) == 0:
raise TimeoutError("Took too long to fetch proof hashes.")
# TODO(mariano): Validate weight
# - Request headers for a random subset
# - Verify those proofs
weight = self.wallet_state_manager.block_records[fork_point_hash].weight
header_validate_start_height = max(
fork_point_height, self.config["starting_height"] - 1
)
if fork_point_height == 0:
difficulty = self.constants["STARTING_DIFFICULTY"]
else:
fork_point_parent_hash = self.wallet_state_manager.block_records[
fork_point_hash
].prev_header_hash
fork_point_parent_weight = self.wallet_state_manager.block_records[
fork_point_parent_hash
].weight
difficulty = uint64(weight - fork_point_parent_weight)
for height in range(fork_point_height + 1, header_validate_start_height):
_, difficulty_change = self.proof_hashes[height]
if difficulty_change is not None:
difficulty = difficulty_change
weight += difficulty
block_record = BlockRecord(
self.header_hashes[height],
self.header_hashes[height - 1],
uint32(height),
weight,
[],
[],
)
res = await self.wallet_state_manager.receive_block(block_record, None)
# Download headers in batches, and verify them as they come in. We download a few batches ahead,
# in case there are delays. TODO(mariano): optimize sync by pipelining
for height in range(0, tip_height + 1):
self.potential_blocks_received[uint32(height)] = asyncio.Event()
last_request_time = float(0)
highest_height_requested = uint32(0)
request_made = False
sleep_interval = 10
for height_checkpoint in range(
header_validate_start_height + 1, tip_height + 1, 1
):
end_height = min(height_checkpoint + 1, tip_height + 1)
total_time_slept = 0
while True:
if self._shut_down:
return
if total_time_slept > timeout:
raise TimeoutError("Took too long to fetch blocks")
# Request batches that we don't have yet
for batch in range(0, self.config["num_sync_batches"]):
batch_start = uint32(height_checkpoint + batch)
batch_end = min(batch_start + 1, tip_height + 1)
if batch_start > tip_height:
# We have asked for all blocks
break
blocks_missing = any(
[
not (self.potential_blocks_received[uint32(h)]).is_set()
for h in range(batch_start, batch_end)
]
)
if (
time.time() - last_request_time > sleep_interval
and blocks_missing
) or (batch_end - 1) > highest_height_requested:
# If we are missing blocks in this batch, and we haven't made a request in a while,
# make a request for this batch. Also, if we have never requested this batch, make
# the request
self.log.info(f"Requesting sync header {batch_start}")
if batch_end - 1 > highest_height_requested:
highest_height_requested = uint32(batch_end - 1)
request_made = True
request_header = wallet_protocol.RequestHeader(
batch_start, self.header_hashes[batch_start],
)
yield OutboundMessage(
NodeType.FULL_NODE,
Message("request_header", request_header),
Delivery.RANDOM,
)
if request_made:
# Reset the timer for requests, so we don't overload other peers with requests
last_request_time = time.time()
request_made = False
awaitables = [
self.potential_blocks_received[uint32(height)].wait()
for height in range(height_checkpoint, end_height)
]
future = asyncio.gather(*awaitables, return_exceptions=True)
try:
await asyncio.wait_for(future, timeout=sleep_interval)
break
except concurrent.futures.TimeoutError:
try:
await future
except asyncio.CancelledError:
pass
total_time_slept += sleep_interval
self.log.info("Did not receive desired headers")
# Verifies this batch, which we are guaranteed to have (since we broke from the above loop)
for height in range(height_checkpoint, end_height):
hh = self.potential_header_hashes[height]
block_record, header_block = self.cached_blocks[hh]
res = await self.wallet_state_manager.receive_block(
block_record, header_block
)
if (
res == ReceiveBlockResult.INVALID_BLOCK
or res == ReceiveBlockResult.DISCONNECTED_BLOCK
):
raise RuntimeError(
f"Invalid block header {block_record.header_hash}"
)
self.log.info(
f"Finished sync process up to height {max(self.wallet_state_manager.height_to_hash.keys())}"
)
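
As a rough, self-contained sketch of the pattern _sync uses above — request several heights ahead, then block on a per-height asyncio.Event for the current checkpoint — with a stub coroutine standing in for the real request_header/respond_header exchange (all names and timings here are illustrative, not from the commit):

import asyncio
from typing import Dict

async def sync_headers(tip_height: int, num_sync_batches: int = 10) -> None:
    # One event per height, set when that header "arrives"
    received: Dict[int, asyncio.Event] = {h: asyncio.Event() for h in range(tip_height + 1)}

    async def stub_peer(height: int) -> None:
        await asyncio.sleep(0.01)  # pretend network latency
        received[height].set()

    for checkpoint in range(tip_height + 1):
        # Request up to num_sync_batches heights ahead of the checkpoint
        for offset in range(num_sync_batches):
            h = checkpoint + offset
            if h > tip_height:
                break
            if not received[h].is_set():
                asyncio.create_task(stub_peer(h))  # may re-request; harmless for the stub
        # Wait (with a timeout) for the checkpoint height itself; validation would happen here
        await asyncio.wait_for(received[checkpoint].wait(), timeout=10)

asyncio.run(sync_headers(20))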
async def _block_finished(
self, block_record: BlockRecord, header_block: HeaderBlock
):
if self.sync_mode:
self.potential_blocks_received[uint32(block_record.height)].set()
self.potential_header_hashes[block_record.height] = block_record.header_hash
self.cached_blocks[block_record.header_hash] = (block_record, header_block)
return
# If the block connects to our chain, add it; if it is disconnected but close, fetch the parent header
lca = self.wallet_state_manager.block_records[self.wallet_state_manager.lca]
if block_record.prev_header_hash in self.wallet_state_manager.block_records:
# We have completed a block that we can add to chain, so add it.
res = await self.wallet_state_manager.receive_block(
block_record, header_block
)
if res == ReceiveBlockResult.DISCONNECTED_BLOCK:
self.log.error("Attempted to add disconnected block")
return
elif res == ReceiveBlockResult.INVALID_BLOCK:
self.log.error("Attempted to add invalid block")
return
elif res == ReceiveBlockResult.ALREADY_HAVE_BLOCK:
return
else:
# If we have the next block available, add it
if block_record.header_hash in self.cached_blocks:
new_br, new_hb = self.cached_blocks[block_record.header_hash]
async for msg in self._block_finished(new_br, new_hb):
yield msg
if res == ReceiveBlockResult.ADDED_TO_HEAD:
# Removes outdated cached blocks if we're not syncing
if not self.sync_mode:
for header_hash in self.cached_blocks:
if (
block_record.height
- self.cached_blocks[header_hash][0].height
> 100
):
del self.cached_blocks[header_hash]
if header_hash in self.cached_additions:
del self.cached_additions[header_hash]
if header_hash in self.cached_removals:
del self.cached_removals[header_hash]
else:
if block_record.height - lca.height < self.short_sync_threshold:
# We have completed a block that is in the near future, so cache it, and fetch parent
self.cached_blocks[block_record.prev_header_hash] = (
block_record,
header_block,
)
header_request = wallet_protocol.RequestHeader(
uint32(block_record.height - 1), block_record.prev_header_hash,
)
yield OutboundMessage(
NodeType.FULL_NODE,
Message("request_header", header_request),
Delivery.RESPOND,
)
return
self.log.warning("Block too far ahead in the future, should never get here")
return
@api_request
async def transaction_ack(self, ack: wallet_protocol.TransactionAck):
@@ -79,31 +368,37 @@ class WalletNode:
async def respond_all_proof_hashes(
self, response: wallet_protocol.RespondAllProofHashes
):
# TODO(mariano): save proof hashes
pass
if not self.sync_mode:
self.log.warning("Receiving proof hashes while not syncing.")
return
self.proof_hashes = response.hashes
@api_request
async def respond_all_header_hashes_after(
self, response: wallet_protocol.RespondAllHeaderHashesAfter
):
# TODO(mariano): save header_hashes
pass
if not self.sync_mode:
self.log.warning("Receiving header hashes while not syncing.")
return
self.header_hashes = response.hashes
@api_request
async def reject_all_header_hashes_after_request(
self, response: wallet_protocol.RejectAllHeaderHashesAfterRequest
):
# TODO(mariano): retry
self.log.error("All header hashes after request rejected")
pass
@api_request
async def new_lca(self, request: wallet_protocol.NewLCA):
if self.sync_mode:
return
# If already seen LCA, ignore.
if request.lca_hash in self.wallet_state_manager.block_records:
return
lca = self.wallet_state_manager.block_records[self.wallet_state_manager.lca]
# If it's not the heaviest chain, ignore.
if request.weight < lca.weight:
return
@@ -111,12 +406,14 @@ class WalletNode:
if int(request.height) - int(lca.height) > self.short_sync_threshold:
try:
# Performs sync, and catch exceptions so we don't close the connection
self.sync_mode = True
async for ret_msg in self._sync():
yield ret_msg
except asyncio.CancelledError:
self.log.error("Syncing failed, CancelledError")
except BaseException as e:
self.log.error(f"Error {type(e)}{e} with syncing")
self.sync_mode = False
else:
header_request = wallet_protocol.RequestHeader(
uint32(request.height - 1), request.prev_header_hash
@@ -130,27 +427,23 @@ class WalletNode:
@api_request
async def respond_header(self, response: wallet_protocol.RespondHeader):
block = response.header_block
# 0. If we already have, return
# If we already have, return
if block.header_hash in self.wallet_state_manager.block_records:
return
lca = self.wallet_state_manager.block_records[self.wallet_state_manager.lca]
block_record = BlockRecord(
block.header_hash,
block.prev_header_hash,
block.height,
block.weight,
[],
[],
)
# 1. If disconnected and close, get parent header and return
if block.prev_header_hash not in self.wallet_state_manager.block_records:
if block.height - lca.height < self.short_sync_threshold:
header_request = wallet_protocol.RequestHeader(
uint32(block.height - 1), block.prev_header_hash,
)
yield OutboundMessage(
NodeType.FULL_NODE,
Message("request_header", header_request),
Delivery.RESPOND,
)
return
# 2. If we have transactions, fetch adds/deletes
# If we have transactions, fetch adds/deletes
if response.transactions_filter is not None:
# Caches the block so we can finalize it when additions and removals arrive
self.cached_blocks[block.header_hash] = (block_record, block)
(
additions,
removals,
@@ -178,90 +471,153 @@ class WalletNode:
Delivery.RESPOND,
)
else:
block_record = BlockRecord(block.header_hash, block.prev_header_hash, block.height, block.weight, [], [])
res = await self.wallet_state_manager.receive_block(block_record, block)
# 3. If we don't have, don't fetch
# 4. If we have the next header cached, process it
pass
# If we don't have any transactions in filter, don't fetch, and finish the block
async for msg in self._block_finished(block_record, block):
yield msg
@api_request
async def reject_header_request(
self, response: wallet_protocol.RejectHeaderRequest
):
# TODO(mariano): implement
pass
self.log.error("Header request rejected")
@api_request
async def respond_removals(self, response: wallet_protocol.RespondRemovals):
# TODO(mariano): implement
pass
if response.header_hash not in self.cached_blocks:
self.log.warning("Do not have header for removals")
return
block_record, header_block = self.cached_blocks[response.header_hash]
assert response.height == block_record.height
removals: List[bytes32]
if response.proofs is None:
# Find our removals
all_coins: List[Coin] = []
for coin_name, coin in response.coins:
if coin is not None:
all_coins.append(coin)
removals = [
c.name()
for c in await self.wallet_state_manager.get_relevant_removals(
all_coins
)
]
else:
removals = []
assert len(response.coins) == len(response.proofs)
for i in range(len(response.coins)):
# Coins are in the same order as proofs
assert response.coins[i][0] == response.proofs[i][0]
coin = response.coins[i][1]
if coin is None:
assert confirm_not_included_already_hashed(
header_block.header.data.removals_root,
response.coins[i][0],
response.proofs[i][1],
)
else:
assert response.coins[i][0] == coin.name()
assert confirm_included_already_hashed(
header_block.header.data.removals_root,
coin.name(),
response.proofs[i][1],
)
removals.append(response.coins[i][0])
additions = self.cached_additions.get(response.header_hash, [])
new_br = BlockRecord(
block_record.header_hash,
block_record.prev_header_hash,
block_record.height,
block_record.weight,
additions,
removals,
)
self.cached_blocks[response.header_hash] = (new_br, header_block)
self.cached_removals[response.header_hash] = removals
if response.header_hash in self.cached_additions:
# We have collected all three things: header, additions, and removals. Can proceed.
# Otherwise, we wait for the additions to arrive
async for msg in self._block_finished(new_br, header_block):
yield msg
@api_request
async def reject_removals_request(
self, response: wallet_protocol.RejectRemovalsRequest
):
# TODO(mariano): implement
pass
self.log.error("Removals request rejected")
# @api_request
# async def received_body(self, response: wallet_protocol.RespondBody):
# """
# Called when body is received from the FullNode
# """
@api_request
async def respond_additions(self, response: wallet_protocol.RespondAdditions):
if response.header_hash not in self.cached_blocks:
self.log.warning("Do not have header for additions")
return
block_record, header_block = self.cached_blocks[response.header_hash]
assert response.height == block_record.height
# # Retry sending queued up transactions
# await self.retry_send_queue()
additions: List[Coin]
if response.proofs is None:
# Find our additions
all_coins: List[Coin] = []
for puzzle_hash, coin_list_0 in response.coins:
all_coins += coin_list_0
additions = await self.wallet_state_manager.get_relevant_additions(
all_coins
)
else:
additions = []
assert len(response.coins) == len(response.proofs)
for i in range(len(response.coins)):
assert response.coins[i][0] == response.proofs[i][0]
coin_list_1: List[Coin] = response.coins[i][1]
puzzle_hash_proof: bytes32 = response.proofs[i][1]
coin_list_proof: Optional[bytes32] = response.proofs[i][2]
if len(coin_list_1) == 0:
# Verify exclusion proof for puzzle hash
assert confirm_not_included_already_hashed(
header_block.header.data.additions_root,
response.coins[i][0],
puzzle_hash_proof,
)
else:
# Verify inclusion proof for puzzle hash
assert confirm_included_already_hashed(
header_block.header.data.additions_root,
response.coins[i][0],
puzzle_hash_proof,
)
# Verify inclusion proof for coin list
assert confirm_included_already_hashed(
header_block.header.data.additions_root,
hash_coin_list(coin_list_1),
coin_list_proof,
)
for coin in coin_list_1:
assert coin.puzzle_hash == response.coins[i][0]
additions += coin_list_1
removals = self.cached_removals.get(response.header_hash, [])
new_br = BlockRecord(
block_record.header_hash,
block_record.prev_header_hash,
block_record.height,
block_record.weight,
additions,
removals,
)
self.cached_blocks[response.header_hash] = (new_br, header_block)
self.cached_additions[response.header_hash] = additions
# additions: List[Coin] = []
if response.header_hash in self.cached_removals:
# We have collected all three things: header, additions, and removals. Can proceed.
# Otherwise, we wait for the removals to arrive
async for msg in self._block_finished(new_br, header_block):
yield msg
# if await self.wallet.can_generate_puzzle_hash(
# response.header.data.coinbase.puzzle_hash
# ):
# await self.wallet_state_manager.coin_added(
# response.header.data.coinbase, response.height, True
# )
# if await self.wallet.can_generate_puzzle_hash(
# response.header.data.fees_coin.puzzle_hash
# ):
# await self.wallet_state_manager.coin_added(
# response.header.data.fees_coin, response.height, True
# )
# npc_list: List[NPC]
# if response.transactions_generator:
# error, npc_list, cost = get_name_puzzle_conditions(
# response.transactions_generator
# )
# additions.extend(additions_for_npc(npc_list))
# for added_coin in additions:
# if await self.wallet.can_generate_puzzle_hash(added_coin.puzzle_hash):
# await self.wallet_state_manager.coin_added(
# added_coin, response.height, False
# )
# for npc in npc_list:
# if await self.wallet.can_generate_puzzle_hash(npc.puzzle_hash):
# await self.wallet_state_manager.coin_removed(
# npc.coin_name, response.height
# )
# async def retry_send_queue(self):
# records = await self.wallet_state_manager.get_send_queue()
# for record in records:
# if record.spend_bundle:
# await self._send_transaction(record.spend_bundle)
# async def _send_transaction(self, spend_bundle: SpendBundle):
# """ Sends spendbundle to connected full Nodes."""
# await self.wallet_state_manager.add_pending_transaction(spend_bundle)
# msg = OutboundMessage(
# NodeType.FULL_NODE,
# Message("wallet_transaction", spend_bundle),
# Delivery.BROADCAST,
# )
# if self.server:
# async for reply in self.server.push_message(msg):
# self.log.info(reply)
@api_request
async def reject_additions_request(
self, response: wallet_protocol.RejectAdditionsRequest
):
# TODO(mariano): implement
self.log.error("Additions request rejected")

View File

@@ -3,7 +3,7 @@ from pathlib import Path
from typing import Dict, Optional, List, Set, Tuple
import logging
import asyncio
from chiabip158 import PyBIP158
from src.types.hashable.coin import Coin
@@ -39,6 +39,9 @@ class WalletStateManager:
lca: Optional[bytes32]
start_index: int
# Makes sure only one asyncio task is changing the blockchain state at a time
lock: asyncio.Lock
log: logging.Logger
# TODO Don't allow user to send tx until wallet is synced
@@ -63,12 +66,14 @@ class WalletStateManager:
self.log = logging.getLogger(name)
else:
self.log = logging.getLogger(__name__)
self.lock = asyncio.Lock()
self.wallet_store = await WalletStore.create(db_path)
self.tx_store = await WalletTransactionStore.create(db_path)
self.puzzle_store = await WalletPuzzleStore.create(db_path)
self.lca = None
self.synced = False
self.height_to_hash = {}
self.block_records = await self.wallet_store.get_lca_path()
genesis = FullBlock.from_bytes(self.constants["GENESIS_BLOCK"])
@@ -281,62 +286,96 @@ class WalletStateManager:
records = await self.tx_store.get_all_transactions()
return records
def find_fork_point(self, alternate_chain: List[bytes32]) -> uint32:
"""
Takes in an alternate blockchain (header hashes) and compares it to self. Returns the height of
the last header at which both blockchains are equal.
"""
lca: BlockRecord = self.block_records[self.lca]
if lca.height >= len(alternate_chain) - 1:
raise ValueError("Alternate chain is shorter")
low: uint32 = uint32(0)
high = lca.height
while low + 1 < high:
mid = uint32((low + high) // 2)
if self.height_to_hash[uint32(mid)] != alternate_chain[mid]:
high = mid
else:
low = mid
if low == high and low == 0:
assert self.height_to_hash[uint32(0)] == alternate_chain[0]
return uint32(0)
assert low + 1 == high
if self.height_to_hash[uint32(low)] == alternate_chain[low]:
if self.height_to_hash[uint32(high)] == alternate_chain[high]:
return high
else:
return low
elif low > 0:
assert self.height_to_hash[uint32(low - 1)] == alternate_chain[low - 1]
return uint32(low - 1)
else:
raise ValueError("Invalid genesis block")
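
A toy walk-through of the binary search above, with both chains reduced to plain lists of header hashes; the helper is a simplified stand-in for find_fork_point, not the commit's code:

from typing import List

def find_fork_point_in_lists(ours: List[bytes], alternate: List[bytes]) -> int:
    # Binary search for the last height where both chains agree
    low, high = 0, min(len(ours), len(alternate)) - 1
    if ours[0] != alternate[0]:
        raise ValueError("Chains do not share a genesis block")
    while low + 1 < high:
        mid = (low + high) // 2
        if ours[mid] == alternate[mid]:
            low = mid
        else:
            high = mid
    return high if ours[high] == alternate[high] else low

ours = [b"g", b"a", b"b", b"c"]
alternate = [b"g", b"a", b"x", b"y", b"z"]
assert find_fork_point_in_lists(ours, alternate) == 1  # chains diverge after height 1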
async def receive_block(
self, block: BlockRecord, header_block: Optional[HeaderBlock] = None,
) -> ReceiveBlockResult:
if block.header_hash in self.block_records:
return ReceiveBlockResult.ALREADY_HAVE_BLOCK
async with self.lock:
if block.header_hash in self.block_records:
return ReceiveBlockResult.ALREADY_HAVE_BLOCK
if block.prev_header_hash not in self.block_records or block.height == 0:
return ReceiveBlockResult.DISCONNECTED_BLOCK
if block.prev_header_hash not in self.block_records and block.height != 0:
return ReceiveBlockResult.DISCONNECTED_BLOCK
if header_block is not None:
# TODO: validate header block
pass
if header_block is not None:
# TODO: validate header block
pass
self.block_records[block.header_hash] = block
await self.wallet_store.add_block_record(block, False)
self.block_records[block.header_hash] = block
await self.wallet_store.add_block_record(block, False)
# Genesis case
if self.lca is None:
assert block.height == 0
await self.wallet_store.add_block_to_path(block.header_hash)
self.lca = block.header_hash
for coin in block.additions:
await self.coin_added(coin, block.height, False)
for coin_name in block.removals:
await self.coin_removed(coin_name, block.height)
self.height_to_hash[uint32(0)] = block.header_hash
return ReceiveBlockResult.ADDED_TO_HEAD
# Genesis case
if self.lca is None:
assert block.height == 0
await self.wallet_store.add_block_to_path(block.header_hash)
self.lca = block.header_hash
for coin in block.additions:
await self.coin_added(coin, block.height, False)
for coin_name in block.removals:
await self.coin_removed(coin_name, block.height)
self.height_to_hash[uint32(0)] = block.header_hash
return ReceiveBlockResult.ADDED_TO_HEAD
# Not genesis, updated LCA
if block.weight > self.block_records[self.lca].weight:
# Not genesis, updated LCA
if block.weight > self.block_records[self.lca].weight:
fork_h = self.find_fork_for_lca(block)
await self.reorg_rollback(fork_h)
fork_h = self.find_fork_for_lca(block)
await self.reorg_rollback(fork_h)
# Add blocks between fork point and new lca
fork_hash = self.height_to_hash[fork_h]
blocks_to_add: List[BlockRecord] = []
tip_hash: bytes32 = block.header_hash
while True:
if tip_hash == fork_hash:
break
record = self.block_records[tip_hash]
blocks_to_add.append(record)
tip_hash = record.prev_header_hash
blocks_to_add.reverse()
# Add blocks between fork point and new lca
fork_hash = self.height_to_hash[fork_h]
blocks_to_add: List[BlockRecord] = []
tip_hash: bytes32 = block.header_hash
while True:
if tip_hash == fork_hash:
break
record = self.block_records[tip_hash]
blocks_to_add.append(record)
tip_hash = record.prev_header_hash
blocks_to_add.reverse()
for path_block in blocks_to_add:
self.height_to_hash[path_block.height] = path_block.header_hash
await self.wallet_store.add_block_to_path(path_block.header_hash)
for coin in path_block.additions:
await self.coin_added(coin, path_block.height, False)
for coin_name in path_block.removals:
await self.coin_removed(coin_name, path_block.height)
return ReceiveBlockResult.ADDED_TO_HEAD
for path_block in blocks_to_add:
self.height_to_hash[path_block.height] = path_block.header_hash
await self.wallet_store.add_block_to_path(path_block.header_hash)
for coin in path_block.additions:
await self.coin_added(coin, path_block.height, False)
for coin_name in path_block.removals:
await self.coin_removed(coin_name, path_block.height)
self.lca = block.header_hash
return ReceiveBlockResult.ADDED_TO_HEAD
return ReceiveBlockResult.ADDED_AS_ORPHAN
return ReceiveBlockResult.ADDED_AS_ORPHAN
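
receive_block now performs its checks and store writes inside a single asyncio.Lock, so two concurrently handled messages cannot both pass the ALREADY_HAVE_BLOCK check and write the same block. A minimal sketch of why that guard matters (the sleep stands in for the awaited wallet_store calls; names are illustrative):

import asyncio

class ChainState:
    """Sketch: serialize state changes behind one asyncio.Lock, as receive_block now does."""

    def __init__(self) -> None:
        self.lock = asyncio.Lock()
        self.blocks: set = set()

    async def receive(self, header_hash: bytes) -> str:
        async with self.lock:
            if header_hash in self.blocks:
                return "ALREADY_HAVE_BLOCK"
            await asyncio.sleep(0)  # stand-in for awaited store writes inside the critical section
            self.blocks.add(header_hash)
            return "ADDED_TO_HEAD"

async def demo() -> None:
    state = ChainState()
    # Two tasks racing on the same block: exactly one of them adds it
    results = await asyncio.gather(state.receive(b"abc"), state.receive(b"abc"))
    assert sorted(results) == ["ADDED_TO_HEAD", "ALREADY_HAVE_BLOCK"]

asyncio.run(demo())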
def find_fork_for_lca(self, new_lca: BlockRecord) -> uint32:
""" Tries to find height where new chain (current) diverged from the old chain where old_lca was the LCA"""
@@ -358,8 +397,7 @@ class WalletStateManager:
self, transactions_fitler: bytes
) -> Tuple[List[bytes32], List[bytes32]]:
""" Returns a list of our coin ids, and a list of puzzle_hashes that positively match with provided filter. """
tx_filter = PyBIP158(transactions_fitler)
tx_filter = PyBIP158([b for b in transactions_fitler])
my_coin_records: Set[
CoinRecord
] = await self.wallet_store.get_coin_records_by_spent(False)
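
The filter change above passes the encoded filter to PyBIP158 as a list of ints rather than a bytes object; iterating a Python bytes value yields ints, which is presumably the element type the chiabip158 binding expects here. A quick illustration of just that conversion (the PyBIP158 call itself is not exercised):

encoded_filter = bytes([0x01, 0x7F, 0xFF])
as_int_list = [b for b in encoded_filter]  # iterating bytes yields ints
assert as_int_list == [1, 127, 255]
# tx_filter = PyBIP158(as_int_list)  # as in the diff; requires the chiabip158 package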

View File

@@ -16,7 +16,6 @@ class WalletStore:
db_connection: aiosqlite.Connection
# Whether or not we are syncing
sync_mode: bool = False
lock: asyncio.Lock
coin_record_cache: Dict[str, CoinRecord]
cache_size: uint32
@@ -221,7 +220,6 @@ class WalletStore:
if br.height > max_height:
max_height = br.height
# Makes sure there's exactly one block per height
print(max_height, len(rows))
assert max_height == len(rows) - 1
return hash_to_br
@@ -232,7 +230,7 @@ class WalletStore:
block_record.header_hash.hex(),
block_record.height,
in_lca_path,
block_record,
bytes(block_record),
),
)
await cursor.close()

View File

@@ -6,7 +6,7 @@ import pytest
from src.types.peer_info import PeerInfo
from src.protocols import full_node_protocol
from src.util.ints import uint16
from tests.setup_nodes import setup_two_nodes, test_constants, bt
from tests.setup_nodes import setup_two_nodes, setup_node_and_wallet, test_constants, bt
@pytest.fixture(scope="module")
@@ -21,6 +21,11 @@ class TestFullSync:
async for _ in setup_two_nodes():
yield _
@pytest.fixture(scope="function")
async def wallet_node(self):
async for _ in setup_node_and_wallet():
yield _
@pytest.mark.asyncio
async def test_basic_sync(self, two_nodes):
num_blocks = 100
@@ -56,6 +61,43 @@ class TestFullSync:
f"Took too long to process blocks, stopped at: {time.time() - start_unf}"
)
@pytest.mark.asyncio
async def test_basic_sync_wallet(self, wallet_node):
num_blocks = 25
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10)
full_node_1, wallet_node, server_1, server_2 = wallet_node
for i in range(1, num_blocks):
async for _ in full_node_1.respond_block(
full_node_protocol.RespondBlock(blocks[i])
):
pass
await server_2.start_client(
PeerInfo(server_1._host, uint16(server_1._port)), None
)
await asyncio.sleep(2) # Allow connections to get made
start_unf = time.time()
while time.time() - start_unf < 60:
# The wallet should eventually catch up to the full node, settling a few blocks
# behind the full node's tip (height num_blocks - 7 here).
if (
wallet_node.wallet_state_manager.block_records[
wallet_node.wallet_state_manager.lca
].height
== num_blocks - 7
):
print(f"Time taken to sync {num_blocks} is {time.time() - start_unf}")
return
await asyncio.sleep(0.1)
raise Exception(
f"Took too long to process blocks, stopped at: {time.time() - start_unf}"
)
@pytest.mark.asyncio
async def test_short_sync(self, two_nodes):
num_blocks = 10

View File

@@ -1,6 +1,8 @@
from typing import Any, Dict
from pathlib import Path
import asyncio
import blspy
from secrets import token_bytes
from src.full_node.blockchain import Blockchain
from src.full_node.mempool_manager import MempoolManager
@@ -8,6 +10,7 @@ from src.full_node.store import FullNodeStore
from src.full_node.full_node import FullNode
from src.server.connection import NodeType
from src.server.server import ChiaServer
from src.wallet.wallet_node import WalletNode
from src.types.full_block import FullBlock
from src.full_node.coin_store import CoinStore
from tests.block_tools import BlockTools
@@ -91,6 +94,33 @@ async def setup_full_node(db_name, port, introducer_port=None, dic={}):
Path(db_name).unlink()
async def setup_wallet_node(port, introducer_port=None, dic={}):
config = load_config("config.yaml", "wallet")
key_config = {
"wallet_sk": bytes(blspy.ExtendedPrivateKey.from_seed(b"1234")).hex(),
}
test_constants_copy = test_constants.copy()
for k in dic.keys():
test_constants_copy[k] = dic[k]
db_path = "test-wallet-db" + token_bytes(32).hex()
if Path(db_path).exists():
Path(db_path).unlink()
wallet = await WalletNode.create(
config, key_config, db_path=db_path, override_constants=test_constants_copy
)
server = ChiaServer(port, wallet, NodeType.WALLET)
wallet.set_server(server)
yield (wallet, server)
server.close_all()
await wallet.wallet_state_manager.clear_all_stores()
await wallet.wallet_state_manager.close_all_stores()
Path(db_path).unlink()
await server.await_closed()
async def setup_harvester(port, dic={}):
config = load_config("config.yaml", "harvester")
@@ -195,6 +225,24 @@ async def setup_two_nodes(dic={}):
pass
async def setup_node_and_wallet(dic={}):
node_iters = [
setup_full_node("blockchain_test.db", 21234, dic=dic),
setup_wallet_node(21235, dic=dic),
]
full_node, s1 = await node_iters[0].__anext__()
wallet, s2 = await node_iters[1].__anext__()
yield (full_node, wallet, s1, s2)
for node_iter in node_iters:
try:
await node_iter.__anext__()
except StopAsyncIteration:
pass
async def setup_full_system(dic={}):
node_iters = [
setup_introducer(21233),