diff --git a/config/config.yaml b/config/config.yaml index d2f0b234..e18703c1 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -4,7 +4,7 @@ ping_interval: 120 # Controls logging of all servers (harvester, farmer, etc..). Each one can be overriden. logging: &logging - log_stdout: False # If True, outputs to stdout instead of a file + log_stdout: True # If True, outputs to stdout instead of a file log_filename: "chia.log" harvester: @@ -124,9 +124,14 @@ introducer: wallet: host: 127.0.0.1 - port: 8223 + port: 8449 rpc_port: 9256 + # The minimum height that we care about for our transactions. Set to zero + # If we are restoring from private key and don't know the height. + starting_height: 0 + num_sync_batches: 10 + full_node_peer: host: 127.0.0.1 port: 8444 diff --git a/src/full_node/full_node.py b/src/full_node/full_node.py index 34426860..998de070 100644 --- a/src/full_node/full_node.py +++ b/src/full_node/full_node.py @@ -1821,7 +1821,7 @@ class FullNode: ) yield OutboundMessage( NodeType.WALLET, - Message("respond_block", response), + Message("respond_header", response), Delivery.RESPOND, ) return diff --git a/src/server/start_wallet.py b/src/server/start_wallet.py index dae5460f..691e49f6 100644 --- a/src/server/start_wallet.py +++ b/src/server/start_wallet.py @@ -41,7 +41,7 @@ async def main(): asyncio.get_running_loop().add_signal_handler(signal.SIGINT, server.close_all) asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, server.close_all) - _ = await server.start_server(config["host"], wallet._on_connect) + _ = await server.start_server(config["host"], None) await asyncio.sleep(1) _ = await server.start_client(full_node_peer, None) diff --git a/src/wallet/wallet_node.py b/src/wallet/wallet_node.py index acdb45e6..ecedba89 100644 --- a/src/wallet/wallet_node.py +++ b/src/wallet/wallet_node.py @@ -1,17 +1,28 @@ from pathlib import Path import asyncio -from typing import Dict, Optional -from blspy import ExtendedPrivateKey +import time +from typing import Dict, Optional, Tuple, List +import concurrent import logging +from blspy import ExtendedPrivateKey +from src.util.merkle_set import ( + confirm_included_already_hashed, + confirm_not_included_already_hashed, +) from src.protocols import wallet_protocol from src.consensus.constants import constants as consensus_constants from src.server.server import ChiaServer from src.server.outbound_message import OutboundMessage, NodeType, Message, Delivery -from src.util.ints import uint32 +from src.util.ints import uint32, uint64 +from src.types.sized_bytes import bytes32 from src.util.api_decorators import api_request from src.wallet.wallet import Wallet from src.wallet.wallet_state_manager import WalletStateManager from src.wallet.block_record import BlockRecord +from src.types.header_block import HeaderBlock +from src.types.full_block import FullBlock +from src.types.hashable.coin import Coin, hash_coin_list +from src.full_node.blockchain import ReceiveBlockResult class WalletNode: @@ -22,12 +33,25 @@ class WalletNode: wallet_state_manager: WalletStateManager log: logging.Logger wallet: Wallet + cached_blocks: Dict[bytes32, Tuple[BlockRecord, HeaderBlock]] + cached_removals: Dict[bytes32, List[bytes32]] + cached_additions: Dict[bytes32, List[Coin]] + proof_hashes: List[Tuple[bytes32, Optional[uint64]]] + header_hashes: List[bytes32] + potential_blocks_received: Dict[uint32, asyncio.Event] + potential_header_hashes: Dict[uint32, bytes32] constants: Dict short_sync_threshold: int + sync_mode: bool + _shut_down: bool @staticmethod async def create( - config: Dict, key_config: Dict, name: str = None, override_constants: Dict = {} + config: Dict, + key_config: Dict, + name: str = None, + db_path=None, + override_constants: Dict = {}, ): self = WalletNode() self.config = config @@ -43,15 +67,31 @@ class WalletNode: self.log = logging.getLogger(__name__) pub_hex = self.private_key.get_public_key().serialize().hex() - path = Path(f"wallet_db_{pub_hex}.db") + if not db_path: + path = Path(f"wallet_db_{pub_hex}.db") + else: + path = db_path self.wallet_state_manager = await WalletStateManager.create( - config, path, override_constants=override_constants, + config, path, override_constants=self.constants, ) self.wallet = await Wallet.create(config, key_config, self.wallet_state_manager) - self.server = None + # Normal operation data + self.cached_blocks = {} + self.cached_removals = {} + self.cached_additions = {} + + # Sync data + self.sync_mode = False + self._shut_down = False + self.proof_hashes = [] + self.header_hashes = [] self.short_sync_threshold = 10 + self.potential_blocks_received = {} + self.potential_header_hashes = {} + + self.server = None return self @@ -59,13 +99,262 @@ class WalletNode: self.server = server self.wallet.set_server(server) + def _shutdown(self): + self._shut_down = True + async def _sync(self): """ Wallet has fallen far behind (or is starting up for the first time), and must be synced up to the tip of the blockchain """ - # TODO(mariano): implement - pass + # 1. Get all header hashes + self.header_hashes = [] + self.proof_hashes = [] + self.potential_header_hashes = {} + genesis = FullBlock.from_bytes(self.constants["GENESIS_BLOCK"]) + genesis_challenge = genesis.proof_of_space.challenge_hash + request_header_hashes = wallet_protocol.RequestAllHeaderHashesAfter( + uint32(0), genesis_challenge + ) + yield OutboundMessage( + NodeType.FULL_NODE, + Message("request_all_header_hashes_after", request_header_hashes), + Delivery.RESPOND, + ) + timeout = 100 + start_wait = time.time() + while time.time() - start_wait < timeout: + if self._shut_down: + return + if len(self.header_hashes) > 0: + break + await asyncio.sleep(0.5) + if len(self.header_hashes) == 0: + raise TimeoutError("Took too long to fetch header hashes.") + + # 2. Find fork point + fork_point_height: uint32 = self.wallet_state_manager.find_fork_point( + self.header_hashes + ) + fork_point_hash: bytes32 = self.header_hashes[fork_point_height] + self.log.info(f"Fork point: {fork_point_hash} at height {fork_point_height}") + tip_height = ( + len(self.header_hashes) - 5 + if len(self.header_hashes) > 5 + else len(self.header_hashes) + ) + + header_validate_start_height: uint32 + if self.config["starting_height"] == 0: + header_validate_start_height = fork_point_height + else: + # Request all proof hashes + request_proof_hashes = wallet_protocol.RequestAllProofHashes() + yield OutboundMessage( + NodeType.FULL_NODE, + Message("request_all_proof_hashes", request_proof_hashes), + Delivery.RESPOND, + ) + start_wait = time.time() + while time.time() - start_wait < timeout: + if self._shut_down: + return + if len(self.proof_hashes) > 0: + break + await asyncio.sleep(0.5) + if len(self.proof_hashes) == 0: + raise TimeoutError("Took too long to fetch proof hashes.") + + # TODO(mariano): Validate weight + # - Request headers for a random subset + # - Verify those proofs + + weight = self.wallet_state_manager.block_records[fork_point_hash].weight + header_validate_start_height = max( + fork_point_height, self.config["starting_height"] - 1 + ) + if fork_point_height == 0: + difficulty = self.constants["STARTING_DIFFICULTY"] + else: + fork_point_parent_hash = self.wallet_state_manager.block_records[ + fork_point_hash + ].prev_header_hash + fork_point_parent_weight = self.wallet_state_manager.block_records[ + fork_point_parent_hash + ] + difficulty = uint64(weight - fork_point_parent_weight) + for height in range(fork_point_height + 1, header_validate_start_height): + _, difficulty_change = self.proof_hashes[height] + if difficulty_change is not None: + difficulty = difficulty_change + weight += difficulty + block_record = BlockRecord( + self.header_hashes[height], + self.header_hashes[height - 1], + uint32(height), + weight, + [], + [], + ) + res = await self.wallet_state_manager.receive_block(block_record, None) + + # Download headers in batches, and verify them as they come in. We download a few batches ahead, + # in case there are delays. TODO(mariano): optimize sync by pipelining + for height in range(0, tip_height + 1): + self.potential_blocks_received[uint32(height)] = asyncio.Event() + last_request_time = float(0) + highest_height_requested = uint32(0) + request_made = False + sleep_interval = 10 + + for height_checkpoint in range( + header_validate_start_height + 1, tip_height + 1, 1 + ): + end_height = min(height_checkpoint + 1, tip_height + 1) + + total_time_slept = 0 + while True: + if self._shut_down: + return + if total_time_slept > timeout: + raise TimeoutError("Took too long to fetch blocks") + + # Request batches that we don't have yet + for batch in range(0, self.config["num_sync_batches"]): + batch_start = uint32(height_checkpoint + batch) + batch_end = min(batch_start + 1, tip_height + 1) + + if batch_start > tip_height: + # We have asked for all blocks + break + + blocks_missing = any( + [ + not (self.potential_blocks_received[uint32(h)]).is_set() + for h in range(batch_start, batch_end) + ] + ) + if ( + time.time() - last_request_time > sleep_interval + and blocks_missing + ) or (batch_end - 1) > highest_height_requested: + # If we are missing blocks in this batch, and we haven't made a request in a while, + # Make a request for this batch. Also, if we have never requested this batch, make + # the request + self.log.info(f"Requesting sync header {batch_start}") + if batch_end - 1 > highest_height_requested: + highest_height_requested = uint32(batch_end - 1) + request_made = True + request_header = wallet_protocol.RequestHeader( + batch_start, self.header_hashes[batch_start], + ) + yield OutboundMessage( + NodeType.FULL_NODE, + Message("request_header", request_header), + Delivery.RANDOM, + ) + if request_made: + # Reset the timer for requests, so we don't overload other peers with requests + last_request_time = time.time() + request_made = False + + awaitables = [ + self.potential_blocks_received[uint32(height)].wait() + for height in range(height_checkpoint, end_height) + ] + future = asyncio.gather(*awaitables, return_exceptions=True) + try: + await asyncio.wait_for(future, timeout=sleep_interval) + break + except concurrent.futures.TimeoutError: + try: + await future + except asyncio.CancelledError: + pass + total_time_slept += sleep_interval + self.log.info("Did not receive desired headers") + + # Verifies this batch, which we are guaranteed to have (since we broke from the above loop) + for height in range(height_checkpoint, end_height): + hh = self.potential_header_hashes[height] + block_record, header_block = self.cached_blocks[hh] + + res = await self.wallet_state_manager.receive_block( + block_record, header_block + ) + if ( + res == ReceiveBlockResult.INVALID_BLOCK + or res == ReceiveBlockResult.DISCONNECTED_BLOCK + ): + raise RuntimeError( + f"Invalid block header {block_record.header_hash}" + ) + self.log.info( + f"Finished sync process up to height {max(self.wallet_state_manager.height_to_hash.keys())}" + ) + + async def _block_finished( + self, block_record: BlockRecord, header_block: HeaderBlock + ): + if self.sync_mode: + self.potential_blocks_received[uint32(block_record.height)].set() + self.potential_header_hashes[block_record.height] = block_record.header_hash + self.cached_blocks[block_record.header_hash] = (block_record, header_block) + return + # 1. If disconnected and close, get parent header and return + lca = self.wallet_state_manager.block_records[self.wallet_state_manager.lca] + if block_record.prev_header_hash in self.wallet_state_manager.block_records: + # We have completed a block that we can add to chain, so add it. + res = await self.wallet_state_manager.receive_block( + block_record, header_block + ) + if res == ReceiveBlockResult.DISCONNECTED_BLOCK: + self.log.error("Attempted to add disconnected block") + return + elif res == ReceiveBlockResult.INVALID_BLOCK: + self.log.error("Attempted to add invalid block") + return + elif res == ReceiveBlockResult.ALREADY_HAVE_BLOCK: + return + else: + # If we have the next block available, add it + if block_record.header_hash in self.cached_blocks: + new_br, new_hb = self.cached_blocks[block_record.header_hash] + async for msg in self._block_finished(new_br, new_hb): + yield msg + if res == ReceiveBlockResult.ADDED_TO_HEAD: + # Removes outdated cached blocks if we're not syncing + if not self.sync_mode: + for header_hash in self.cached_blocks: + if ( + block_record.height + - self.cached_blocks[header_hash][0].height + > 100 + ): + del self.cached_blocks[header_hash] + if header_hash in self.cached_additions: + del self.cached_additions[header_hash] + if header_hash in self.cached_removals: + del self.cached_removals[header_hash] + else: + if block_record.height - lca.height < self.short_sync_threshold: + # We have completed a block that is in the near future, so cache it, and fetch parent + self.cached_blocks[block_record.prev_header_hash] = ( + block_record, + header_block, + ) + + header_request = wallet_protocol.RequestHeader( + uint32(block_record.height - 1), block_record.prev_header_hash, + ) + yield OutboundMessage( + NodeType.FULL_NODE, + Message("request_header", header_request), + Delivery.RESPOND, + ) + return + self.log.warning("Block too far ahead in the future, should never get here") + return @api_request async def transaction_ack(self, ack: wallet_protocol.TransactionAck): @@ -79,31 +368,37 @@ class WalletNode: async def respond_all_proof_hashes( self, response: wallet_protocol.RespondAllProofHashes ): - # TODO(mariano): save proof hashes - pass + if not self.sync_mode: + self.log.warning("Receiving proof hashes while not syncing.") + return + self.proof_hashes = response.hashes @api_request async def respond_all_header_hashes_after( self, response: wallet_protocol.RespondAllHeaderHashesAfter ): - # TODO(mariano): save header_hashes - pass + if not self.sync_mode: + self.log.warning("Receiving header hashes while not syncing.") + return + self.header_hashes = response.hashes @api_request async def reject_all_header_hashes_after_request( self, response: wallet_protocol.RejectAllHeaderHashesAfterRequest ): # TODO(mariano): retry + self.log.error("All header hashes after request rejected") pass @api_request async def new_lca(self, request: wallet_protocol.NewLCA): + if self.sync_mode: + return # If already seen LCA, ignore. if request.lca_hash in self.wallet_state_manager.block_records: return lca = self.wallet_state_manager.block_records[self.wallet_state_manager.lca] - # If it's not the heaviest chain, ignore. if request.weight < lca.weight: return @@ -111,12 +406,14 @@ class WalletNode: if int(request.height) - int(lca.height) > self.short_sync_threshold: try: # Performs sync, and catch exceptions so we don't close the connection + self.sync_mode = True async for ret_msg in self._sync(): yield ret_msg except asyncio.CancelledError: self.log.error("Syncing failed, CancelledError") except BaseException as e: self.log.error(f"Error {type(e)}{e} with syncing") + self.sync_mode = False else: header_request = wallet_protocol.RequestHeader( uint32(request.height - 1), request.prev_header_hash @@ -130,27 +427,23 @@ class WalletNode: @api_request async def respond_header(self, response: wallet_protocol.RespondHeader): block = response.header_block - # 0. If we already have, return + # If we already have, return if block.header_hash in self.wallet_state_manager.block_records: return - lca = self.wallet_state_manager.block_records[self.wallet_state_manager.lca] + block_record = BlockRecord( + block.header_hash, + block.prev_header_hash, + block.height, + block.weight, + [], + [], + ) - # 1. If disconnected and close, get parent header and return - if block.prev_header_hash not in self.wallet_state_manager.block_records: - if block.height - lca.height < self.short_sync_threshold: - header_request = wallet_protocol.RequestHeader( - uint32(block.height - 1), block.prev_header_hash, - ) - yield OutboundMessage( - NodeType.FULL_NODE, - Message("request_header", header_request), - Delivery.RESPOND, - ) - return - - # 2. If we have transactions, fetch adds/deletes + # If we have transactions, fetch adds/deletes if response.transactions_filter is not None: + # Caches the block so we can finalize it when additions and removals arrive + self.cached_blocks[block.header_hash] = (block_record, block) ( additions, removals, @@ -178,90 +471,153 @@ class WalletNode: Delivery.RESPOND, ) else: - block_record = BlockRecord(block.header_hash, block.prev_header_hash, block.height, block.weight, [], []) - res = await self.wallet_state_manager.receive_block(block_record, block) - # 3. If we don't have, don't fetch - # 4. If we have the next header cached, process it - pass + # If we don't have any transactions in filter, don't fetch, and finish the block + async for msg in self._block_finished(block_record, block): + yield msg @api_request async def reject_header_request( self, response: wallet_protocol.RejectHeaderRequest ): # TODO(mariano): implement - pass + self.log.error("Header request rejected") @api_request async def respond_removals(self, response: wallet_protocol.RespondRemovals): - # TODO(mariano): implement - pass + if response.header_hash not in self.cached_blocks: + self.log.warning("Do not have header for removals") + return + block_record, header_block = self.cached_blocks[response.header_hash] + assert response.height == block_record.height + + removals: List[bytes32] + if response.proofs is None: + # Find our removals + all_coins: List[Coin] = [] + for coin_name, coin in response.coins: + if coin is not None: + all_coins.append(coin) + removals = [ + c.name() + for c in await self.wallet_state_manager.get_relevant_removals( + all_coins + ) + ] + else: + removals = [] + assert len(response.coins) == len(response.proofs) + for i in range(len(response.coins)): + # Coins are in the same order as proofs + assert response.coins[i][0] == response.proofs[i][0] + coin = response.coins[i][1] + if coin is None: + assert confirm_not_included_already_hashed( + header_block.header.data.removals_root, + response.coins[i][0], + response.proofs[i][1], + ) + else: + assert response.coins[i][0] == coin.name + assert confirm_included_already_hashed( + header_block.header.data.removals_root, + coin.name(), + response.proofs[i][1], + ) + removals.append(response.coins[i][0]) + additions = self.cached_additions.get(response.header_hash, []) + new_br = BlockRecord( + block_record.header_hash, + block_record.prev_header_hash, + block_record.height, + block_record.weight, + additions, + removals, + ) + self.cached_blocks[response.header_hash] = (new_br, header_block) + self.cached_removals[response.header_hash] = removals + + if response.header_hash in self.cached_additions: + # We have collected all three things: header, additions, and removals. Can proceed. + # Otherwise, we wait for the additions to arrive + async for msg in self._block_finished(new_br, header_block): + yield msg @api_request async def reject_removals_request( self, response: wallet_protocol.RejectRemovalsRequest ): # TODO(mariano): implement - pass + self.log.error("Removals request rejected") - # @api_request - # async def received_body(self, response: wallet_protocol.RespondBody): - # """ - # Called when body is received from the FullNode - # """ + @api_request + async def respond_additions(self, response: wallet_protocol.RespondAdditions): + if response.header_hash not in self.cached_blocks: + self.log.warning("Do not have header for additions") + return + block_record, header_block = self.cached_blocks[response.header_hash] + assert response.height == block_record.height - # # Retry sending queued up transactions - # await self.retry_send_queue() + additions: List[Coin] + if response.proofs is None: + # Find our removals + all_coins: List[Coin] = [] + for puzzle_hash, coin_list_0 in response.coins: + all_coins += coin_list_0 + additions = await self.wallet_state_manager.get_relevant_additions( + all_coins + ) + else: + additions = [] + assert len(response.coins) == len(response.proofs) + for i in range(len(response.coins)): + assert response.coins[i][0] == response.proofs[i][0] + coin_list_1: List[Coin] = response.coins[i][1] + puzzle_hash_proof: bytes32 = response.proofs[i][1] + coin_list_proof: Optional[bytes32] = response.proofs[i][2] + if len(coin_list_1) == 0: + # Verify exclusion proof for puzzle hash + assert confirm_not_included_already_hashed( + header_block.header.data.additions_root, + response.coins[i][0], + puzzle_hash_proof, + ) + else: + # Verify inclusion proof for puzzle hash + assert confirm_included_already_hashed( + header_block.header.data.additions_root, + response.coins[i][0], + puzzle_hash_proof, + ) + # Verify inclusion proof for coin list + assert confirm_included_already_hashed( + header_block.header.data.additions_root, + hash_coin_list(coin_list_1), + coin_list_proof, + ) + for coin in coin_list_1: + assert coin.puzzle_hash == response.coins[i][0] + additions += coin_list_1 + removals = self.cached_removals.get(response.header_hash, []) + new_br = BlockRecord( + block_record.header_hash, + block_record.prev_header_hash, + block_record.height, + block_record.weight, + additions, + removals, + ) + self.cached_blocks[response.header_hash] = (new_br, header_block) + self.cached_additions[response.header_hash] = additions - # additions: List[Coin] = [] + if response.header_hash in self.cached_removals: + # We have collected all three things: header, additions, and removals. Can proceed. + # Otherwise, we wait for the removals to arrive + async for msg in self._block_finished(new_br, header_block): + yield msg - # if await self.wallet.can_generate_puzzle_hash( - # response.header.data.coinbase.puzzle_hash - # ): - # await self.wallet_state_manager.coin_added( - # response.header.data.coinbase, response.height, True - # ) - # if await self.wallet.can_generate_puzzle_hash( - # response.header.data.fees_coin.puzzle_hash - # ): - # await self.wallet_state_manager.coin_added( - # response.header.data.fees_coin, response.height, True - # ) - - # npc_list: List[NPC] - # if response.transactions_generator: - # error, npc_list, cost = get_name_puzzle_conditions( - # response.transactions_generator - # ) - - # additions.extend(additions_for_npc(npc_list)) - - # for added_coin in additions: - # if await self.wallet.can_generate_puzzle_hash(added_coin.puzzle_hash): - # await self.wallet_state_manager.coin_added( - # added_coin, response.height, False - # ) - - # for npc in npc_list: - # if await self.wallet.can_generate_puzzle_hash(npc.puzzle_hash): - # await self.wallet_state_manager.coin_removed( - # npc.coin_name, response.height - # ) - - # async def retry_send_queue(self): - # records = await self.wallet_state_manager.get_send_queue() - # for record in records: - # if record.spend_bundle: - # await self._send_transaction(record.spend_bundle) - - # async def _send_transaction(self, spend_bundle: SpendBundle): - # """ Sends spendbundle to connected full Nodes.""" - # await self.wallet_state_manager.add_pending_transaction(spend_bundle) - - # msg = OutboundMessage( - # NodeType.FULL_NODE, - # Message("wallet_transaction", spend_bundle), - # Delivery.BROADCAST, - # ) - # if self.server: - # async for reply in self.server.push_message(msg): - # self.log.info(reply) + @api_request + async def reject_additions_request( + self, response: wallet_protocol.RejectAdditionsRequest + ): + # TODO(mariano): implement + self.log.error("Additions request rejected") diff --git a/src/wallet/wallet_state_manager.py b/src/wallet/wallet_state_manager.py index a630b554..a61910b1 100644 --- a/src/wallet/wallet_state_manager.py +++ b/src/wallet/wallet_state_manager.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import Dict, Optional, List, Set, Tuple import logging - +import asyncio from chiabip158 import PyBIP158 from src.types.hashable.coin import Coin @@ -39,6 +39,9 @@ class WalletStateManager: lca: Optional[bytes32] start_index: int + # Makes sure only one asyncio thread is changing the blockchain state at one time + lock: asyncio.Lock + log: logging.Logger # TODO Don't allow user to send tx until wallet is synced @@ -63,12 +66,14 @@ class WalletStateManager: self.log = logging.getLogger(name) else: self.log = logging.getLogger(__name__) + self.lock = asyncio.Lock() self.wallet_store = await WalletStore.create(db_path) self.tx_store = await WalletTransactionStore.create(db_path) self.puzzle_store = await WalletPuzzleStore.create(db_path) - + self.lca = None self.synced = False + self.height_to_hash = {} self.block_records = await self.wallet_store.get_lca_path() genesis = FullBlock.from_bytes(self.constants["GENESIS_BLOCK"]) @@ -281,62 +286,96 @@ class WalletStateManager: records = await self.tx_store.get_all_transactions() return records + def find_fork_point(self, alternate_chain: List[bytes32]) -> uint32: + """ + Takes in an alternate blockchain (headers), and compares it to self. Returns the last header + where both blockchains are equal. + """ + lca: BlockRecord = self.block_records[self.lca] + + if lca.height >= len(alternate_chain) - 1: + raise ValueError("Alternate chain is shorter") + low: uint32 = uint32(0) + high = lca.height + while low + 1 < high: + mid = uint32((low + high) // 2) + if self.height_to_hash[uint32(mid)] != alternate_chain[mid]: + high = mid + else: + low = mid + if low == high and low == 0: + assert self.height_to_hash[uint32(0)] == alternate_chain[0] + return uint32(0) + assert low + 1 == high + if self.height_to_hash[uint32(low)] == alternate_chain[low]: + if self.height_to_hash[uint32(high)] == alternate_chain[high]: + return high + else: + return low + elif low > 0: + assert self.height_to_hash[uint32(low - 1)] == alternate_chain[low - 1] + return uint32(low - 1) + else: + raise ValueError("Invalid genesis block") + async def receive_block( self, block: BlockRecord, header_block: Optional[HeaderBlock] = None, ) -> ReceiveBlockResult: - if block.header_hash in self.block_records: - return ReceiveBlockResult.ALREADY_HAVE_BLOCK + async with self.lock: + if block.header_hash in self.block_records: + return ReceiveBlockResult.ALREADY_HAVE_BLOCK - if block.prev_header_hash not in self.block_records or block.height == 0: - return ReceiveBlockResult.DISCONNECTED_BLOCK + if block.prev_header_hash not in self.block_records and block.height != 0: + return ReceiveBlockResult.DISCONNECTED_BLOCK - if header_block is not None: - # TODO: validate header block - pass + if header_block is not None: + # TODO: validate header block + pass - self.block_records[block.header_hash] = block - await self.wallet_store.add_block_record(block, False) + self.block_records[block.header_hash] = block + await self.wallet_store.add_block_record(block, False) - # Genesis case - if self.lca is None: - assert block.height == 0 - await self.wallet_store.add_block_to_path(block.header_hash) - self.lca = block.header_hash - for coin in block.additions: - await self.coin_added(coin, block.height, False) - for coin_name in block.removals: - await self.coin_removed(coin_name, block.height) - self.height_to_hash[uint32(0)] = block.header_hash - return ReceiveBlockResult.ADDED_TO_HEAD + # Genesis case + if self.lca is None: + assert block.height == 0 + await self.wallet_store.add_block_to_path(block.header_hash) + self.lca = block.header_hash + for coin in block.additions: + await self.coin_added(coin, block.height, False) + for coin_name in block.removals: + await self.coin_removed(coin_name, block.height) + self.height_to_hash[uint32(0)] = block.header_hash + return ReceiveBlockResult.ADDED_TO_HEAD - # Not genesis, updated LCA - if block.weight > self.block_records[self.lca].weight: + # Not genesis, updated LCA + if block.weight > self.block_records[self.lca].weight: - fork_h = self.find_fork_for_lca(block) - await self.reorg_rollback(fork_h) + fork_h = self.find_fork_for_lca(block) + await self.reorg_rollback(fork_h) - # Add blocks between fork point and new lca - fork_hash = self.height_to_hash[fork_h] - blocks_to_add: List[BlockRecord] = [] - tip_hash: bytes32 = block.header_hash - while True: - if tip_hash == fork_hash: - break - record = self.block_records[tip_hash] - blocks_to_add.append(record) - tip_hash = record.prev_header_hash - blocks_to_add.reverse() + # Add blocks between fork point and new lca + fork_hash = self.height_to_hash[fork_h] + blocks_to_add: List[BlockRecord] = [] + tip_hash: bytes32 = block.header_hash + while True: + if tip_hash == fork_hash: + break + record = self.block_records[tip_hash] + blocks_to_add.append(record) + tip_hash = record.prev_header_hash + blocks_to_add.reverse() - for path_block in blocks_to_add: - self.height_to_hash[path_block.height] = path_block.header_hash - await self.wallet_store.add_block_to_path(path_block.header_hash) - for coin in path_block.additions: - await self.coin_added(coin, path_block.height, False) - for coin_name in path_block.removals: - await self.coin_removed(coin_name, path_block.height) - return ReceiveBlockResult.ADDED_TO_HEAD + for path_block in blocks_to_add: + self.height_to_hash[path_block.height] = path_block.header_hash + await self.wallet_store.add_block_to_path(path_block.header_hash) + for coin in path_block.additions: + await self.coin_added(coin, path_block.height, False) + for coin_name in path_block.removals: + await self.coin_removed(coin_name, path_block.height) + self.lca = block.header_hash + return ReceiveBlockResult.ADDED_TO_HEAD - return ReceiveBlockResult.ADDED_AS_ORPHAN + return ReceiveBlockResult.ADDED_AS_ORPHAN def find_fork_for_lca(self, new_lca: BlockRecord) -> uint32: """ Tries to find height where new chain (current) diverged from the old chain where old_lca was the LCA""" @@ -358,8 +397,7 @@ class WalletStateManager: self, transactions_fitler: bytes ) -> Tuple[List[bytes32], List[bytes32]]: """ Returns a list of our coin ids, and a list of puzzle_hashes that positively match with provided filter. """ - - tx_filter = PyBIP158(transactions_fitler) + tx_filter = PyBIP158([b for b in transactions_fitler]) my_coin_records: Set[ CoinRecord ] = await self.wallet_store.get_coin_records_by_spent(False) diff --git a/src/wallet/wallet_store.py b/src/wallet/wallet_store.py index 8aa9b404..49363308 100644 --- a/src/wallet/wallet_store.py +++ b/src/wallet/wallet_store.py @@ -16,7 +16,6 @@ class WalletStore: db_connection: aiosqlite.Connection # Whether or not we are syncing - sync_mode: bool = False lock: asyncio.Lock coin_record_cache: Dict[str, CoinRecord] cache_size: uint32 @@ -221,7 +220,6 @@ class WalletStore: if br.height > max_height: max_height = br.height # Makes sure there's exactly one block per height - print(max_height, len(rows)) assert max_height == len(rows) - 1 return hash_to_br @@ -232,7 +230,7 @@ class WalletStore: block_record.header_hash.hex(), block_record.height, in_lca_path, - block_record, + bytes(block_record), ), ) await cursor.close() diff --git a/tests/full_node/test_full_sync.py b/tests/full_node/test_full_sync.py index 9983930e..bfb1cc26 100644 --- a/tests/full_node/test_full_sync.py +++ b/tests/full_node/test_full_sync.py @@ -6,7 +6,7 @@ import pytest from src.types.peer_info import PeerInfo from src.protocols import full_node_protocol from src.util.ints import uint16 -from tests.setup_nodes import setup_two_nodes, test_constants, bt +from tests.setup_nodes import setup_two_nodes, setup_node_and_wallet, test_constants, bt @pytest.fixture(scope="module") @@ -21,6 +21,11 @@ class TestFullSync: async for _ in setup_two_nodes(): yield _ + @pytest.fixture(scope="function") + async def wallet_node(self): + async for _ in setup_node_and_wallet(): + yield _ + @pytest.mark.asyncio async def test_basic_sync(self, two_nodes): num_blocks = 100 @@ -56,6 +61,43 @@ class TestFullSync: f"Took too long to process blocks, stopped at: {time.time() - start_unf}" ) + @pytest.mark.asyncio + async def test_basic_sync_wallet(self, wallet_node): + num_blocks = 25 + blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10) + full_node_1, wallet_node, server_1, server_2 = wallet_node + + for i in range(1, num_blocks): + async for _ in full_node_1.respond_block( + full_node_protocol.RespondBlock(blocks[i]) + ): + pass + + await server_2.start_client( + PeerInfo(server_1._host, uint16(server_1._port)), None + ) + + await asyncio.sleep(2) # Allow connections to get made + start_unf = time.time() + + while time.time() - start_unf < 60: + # The second node should eventually catch up to the first one, and have the + # same tip at height num_blocks - 1. + if ( + wallet_node.wallet_state_manager.block_records[ + wallet_node.wallet_state_manager.lca + ].height + == num_blocks - 7 + ): + print(f"Time taken to sync {num_blocks} is {time.time() - start_unf}") + + return + await asyncio.sleep(0.1) + + raise Exception( + f"Took too long to process blocks, stopped at: {time.time() - start_unf}" + ) + @pytest.mark.asyncio async def test_short_sync(self, two_nodes): num_blocks = 10 diff --git a/tests/setup_nodes.py b/tests/setup_nodes.py index 38c39dfa..25a10745 100644 --- a/tests/setup_nodes.py +++ b/tests/setup_nodes.py @@ -1,6 +1,8 @@ from typing import Any, Dict from pathlib import Path import asyncio +import blspy +from secrets import token_bytes from src.full_node.blockchain import Blockchain from src.full_node.mempool_manager import MempoolManager @@ -8,6 +10,7 @@ from src.full_node.store import FullNodeStore from src.full_node.full_node import FullNode from src.server.connection import NodeType from src.server.server import ChiaServer +from src.wallet.wallet_node import WalletNode from src.types.full_block import FullBlock from src.full_node.coin_store import CoinStore from tests.block_tools import BlockTools @@ -91,6 +94,33 @@ async def setup_full_node(db_name, port, introducer_port=None, dic={}): Path(db_name).unlink() +async def setup_wallet_node(port, introducer_port=None, dic={}): + config = load_config("config.yaml", "wallet") + key_config = { + "wallet_sk": bytes(blspy.ExtendedPrivateKey.from_seed(b"1234")).hex(), + } + test_constants_copy = test_constants.copy() + for k in dic.keys(): + test_constants_copy[k] = dic[k] + db_path = "test-wallet-db" + token_bytes(32).hex() + if Path(db_path).exists(): + Path(db_path).unlink() + + wallet = await WalletNode.create( + config, key_config, db_path=db_path, override_constants=test_constants_copy + ) + server = ChiaServer(port, wallet, NodeType.WALLET) + wallet.set_server(server) + + yield (wallet, server) + + server.close_all() + await wallet.wallet_state_manager.clear_all_stores() + await wallet.wallet_state_manager.close_all_stores() + Path(db_path).unlink() + await server.await_closed() + + async def setup_harvester(port, dic={}): config = load_config("config.yaml", "harvester") @@ -195,6 +225,24 @@ async def setup_two_nodes(dic={}): pass +async def setup_node_and_wallet(dic={}): + node_iters = [ + setup_full_node("blockchain_test.db", 21234, dic=dic), + setup_wallet_node(21235, dic=dic), + ] + + full_node, s1 = await node_iters[0].__anext__() + wallet, s2 = await node_iters[1].__anext__() + + yield (full_node, wallet, s1, s2) + + for node_iter in node_iters: + try: + await node_iter.__anext__() + except StopAsyncIteration: + pass + + async def setup_full_system(dic={}): node_iters = [ setup_introducer(21233),