diff --git a/src/full_node/block_store.py b/src/full_node/block_store.py new file mode 100644 index 00000000..b13fce52 --- /dev/null +++ b/src/full_node/block_store.py @@ -0,0 +1,134 @@ +import asyncio +import logging +import aiosqlite +from typing import Dict, List, Optional, Tuple + +from src.types.program import Program +from src.types.full_block import FullBlock +from src.types.header import HeaderData, Header +from src.types.header_block import HeaderBlock +from src.types.proof_of_space import ProofOfSpace +from src.types.sized_bytes import bytes32 +from src.util.hash import std_hash +from src.util.ints import uint32, uint64 + +log = logging.getLogger(__name__) + + +class BlockStore: + db: aiosqlite.Connection + + @classmethod + async def create(cls, connection): + self = cls() + + # All full blocks which have been added to the blockchain. Header_hash -> block + self.db = connection + await self.db.execute( + "CREATE TABLE IF NOT EXISTS blocks(height bigint, header_hash text PRIMARY KEY, block blob)" + ) + + # Headers + await self.db.execute( + "CREATE TABLE IF NOT EXISTS headers(height bigint, header_hash " + "text PRIMARY KEY, proof_hash text, header blob)" + ) + + # LCA + await self.db.execute( + "CREATE TABLE IF NOT EXISTS lca(header_hash text PRIMARY KEY)" + ) + + # Height index so we can look up in order of height for sync purposes + await self.db.execute( + "CREATE INDEX IF NOT EXISTS block_height on blocks(height)" + ) + await self.db.execute( + "CREATE INDEX IF NOT EXISTS header_height on headers(height)" + ) + + await self.db.commit() + + return self + + async def _clear_database(self): + async with self.lock: + await self.db.execute("DELETE FROM blocks") + await self.db.execute("DELETE FROM headers") + await self.db.commit() + + async def get_lca(self) -> Optional[bytes32]: + cursor = await self.db.execute("SELECT * from lca") + row = await cursor.fetchone() + await cursor.close() + if row is not None: + return bytes32(bytes.fromhex(row[0])) + return None + + async def set_lca(self, header_hash: bytes32) -> None: + await self.db.execute("DELETE FROM lca") + cursor_1 = await self.db.execute( + "INSERT OR REPLACE INTO lca VALUES(?)", (header_hash.hex(),) + ) + await cursor_1.close() + await self.db.commit() + + async def add_block(self, block: FullBlock) -> None: + assert block.proof_of_time is not None + cursor_1 = await self.db.execute( + "INSERT OR REPLACE INTO blocks VALUES(?, ?, ?)", + (block.height, block.header_hash.hex(), bytes(block)), + ) + await cursor_1.close() + proof_hash = std_hash( + block.proof_of_space.get_hash() + block.proof_of_time.output.get_hash() + ) + cursor_2 = await self.db.execute( + ("INSERT OR REPLACE INTO headers VALUES(?, ?, ?, ?)"), + ( + block.height, + block.header_hash.hex(), + proof_hash.hex(), + bytes(block.header), + ), + ) + await cursor_2.close() + await self.db.commit() + + async def get_block(self, header_hash: bytes32) -> Optional[FullBlock]: + cursor = await self.db.execute( + "SELECT * from blocks WHERE header_hash=?", (header_hash.hex(),) + ) + row = await cursor.fetchone() + await cursor.close() + if row is not None: + return FullBlock.from_bytes(row[2]) + return None + + async def get_blocks_at(self, heights: List[uint32]) -> List[FullBlock]: + if len(heights) == 0: + return [] + + heights_db = tuple(heights) + formatted_str = ( + f'SELECT * from blocks WHERE height in ({"?," * (len(heights_db) - 1)}?)' + ) + cursor = await self.db.execute(formatted_str, heights_db) + rows = await cursor.fetchall() + await cursor.close() + blocks: List[FullBlock] = [] + for row in rows: + blocks.append(FullBlock.from_bytes(row[2])) + return blocks + + async def get_headers(self) -> List[Header]: + cursor = await self.db.execute("SELECT * from headers") + rows = await cursor.fetchall() + await cursor.close() + return [Header.from_bytes(row[3]) for row in rows] + + async def get_proof_hashes(self) -> Dict[bytes32, bytes32]: + cursor = await self.db.execute("SELECT header_hash, proof_hash from headers") + rows = await cursor.fetchall() + await cursor.close() + return {bytes.fromhex(row[0]): bytes.fromhex(row[1]) for row in rows} diff --git a/src/full_node/blockchain.py b/src/full_node/blockchain.py index 33b0cb95..2aef87fc 100644 --- a/src/full_node/blockchain.py +++ b/src/full_node/blockchain.py @@ -16,7 +16,7 @@ from src.consensus.pot_iterations import ( calculate_min_iters_from_iterations, calculate_iterations_quality, ) -from src.full_node.store import FullNodeStore +from src.full_node.block_store import BlockStore from src.types.condition_opcodes import ConditionOpcode from src.types.condition_var_pair import ConditionVarPair @@ -75,15 +75,18 @@ class Blockchain: # Genesis block genesis: FullBlock # Unspent Store - unspent_store: CoinStore + coin_store: CoinStore # Store - store: FullNodeStore + block_store: BlockStore # Coinbase freeze period coinbase_freeze: uint32 + # Lock to prevent simultaneous reads and writes + lock: asyncio.Lock + @staticmethod async def create( - unspent_store: CoinStore, store: FullNodeStore, override_constants: Dict = {}, + coin_store: CoinStore, block_store: BlockStore, override_constants: Dict = {}, ): """ Initializes a blockchain with the header blocks from disk, assuming they have all been @@ -91,6 +94,7 @@ class Blockchain: in the consensus constants config. """ self = Blockchain() + self.lock = asyncio.Lock() # External lock handled by full node self.constants = consensus_constants.copy() for key, value in override_constants.items(): self.constants[key] = value @@ -98,12 +102,12 @@ class Blockchain: self.height_to_hash = {} self.headers = {} - self.unspent_store = unspent_store - self.store = store + self.coin_store = coin_store + self.block_store = block_store self.genesis = FullBlock.from_bytes(self.constants["GENESIS_BLOCK"]) self.coinbase_freeze = self.constants["COINBASE_FREEZE_PERIOD"] - result, removed, error_code = await self.receive_block(self.genesis) + result, removed, error_code = await self.receive_block(self.genesis, sync_mode=False) if result != ReceiveBlockResult.ADDED_TO_HEAD: if error_code is not None: raise ConsensusError(error_code) @@ -126,12 +130,12 @@ class Blockchain: self.height_to_hash[header.height] = header.header_hash self.tips = [header] self.lca_block = header - await self._reconsider_heads(self.lca_block, False) + await self._reconsider_heads(self.lca_block, False, False) else: for _, header in sorted_headers: # Reconsider every single header, since we don't have LCA on disk self.height_to_hash[header.height] = header.header_hash - await self._reconsider_heads(header, False) + await self._reconsider_heads(header, False, False) assert ( self.headers[self.height_to_hash[uint32(0)]].get_hash() == self.genesis.header_hash @@ -151,10 +155,10 @@ class Blockchain: Loads headers from disk, into a list of Headers, that can be used to initialize the Blockchain class. """ - lca_hash: Optional[bytes32] = await self.store.get_lca() + lca_hash: Optional[bytes32] = await self.block_store.get_lca() seen_blocks: Dict[str, Header] = {} tips: List[Header] = [] - for header in await self.store.get_headers(): + for header in await self.block_store.get_headers(): if lca_hash is not None: if header.header_hash == lca_hash: tips = [header] @@ -185,7 +189,7 @@ class Blockchain: """ Return list of FullBlocks that are tips""" result: List[FullBlock] = [] for tip in self.tips: - block = await self.store.get_block(tip.header_hash) + block = await self.block_store.get_block(tip.header_hash) if not block: continue result.append(block) @@ -515,6 +519,7 @@ class Blockchain: block: FullBlock, pre_validated: bool = False, pos_quality_string: bytes32 = None, + sync_mode: bool = False, ) -> Tuple[ReceiveBlockResult, Optional[Header], Optional[Err]]: """ Adds a new block into the blockchain, if it's valid and connected to the current @@ -541,8 +546,8 @@ class Blockchain: self.headers[block.header_hash] = block.header # Always immediately add the block to the database, after updating blockchain state - await self.store.add_block(block) - res, header = await self._reconsider_heads(block.header, genesis) + await self.block_store.add_block(block) + res, header = await self._reconsider_heads(block.header, genesis, sync_mode) if res: return ReceiveBlockResult.ADDED_TO_HEAD, header, None else: @@ -751,7 +756,7 @@ class Blockchain: """ prev_full_block: Optional[FullBlock] if not genesis: - prev_full_block = await self.store.get_block(block.prev_header_hash) + prev_full_block = await self.block_store.get_block(block.prev_header_hash) if prev_full_block is None: return Err.DOES_NOT_EXTEND else: @@ -855,7 +860,7 @@ class Blockchain: return True, bytes(pos_quality_string) async def _reconsider_heads( - self, block: Header, genesis: bool + self, block: Header, genesis: bool, sync_mode: bool ) -> Tuple[bool, Optional[Header]]: """ When a new block is added, this is called, to check if the new block is heavier @@ -868,11 +873,11 @@ class Blockchain: self.tips.sort(key=lambda b: b.weight, reverse=True) # This will loop only once removed = self.tips.pop() - await self._reconsider_lca(genesis) + await self._reconsider_lca(genesis, sync_mode) return True, removed return False, None - async def _reconsider_lca(self, genesis: bool): + async def _reconsider_lca(self, genesis: bool, sync_mode: bool): """ Update the least common ancestor of the heads. This is useful, since we can just assume there is one block per height before the LCA (and use the height_to_hash dict). @@ -894,38 +899,38 @@ class Blockchain: self.lca_block = cur[0] if old_lca is None: - full: Optional[FullBlock] = await self.store.get_block( + full: Optional[FullBlock] = await self.block_store.get_block( self.lca_block.header_hash ) assert full is not None - await self.unspent_store.new_lca(full) + await self.coin_store.new_lca(full) await self._create_diffs_for_tips(self.lca_block) if not genesis: - await self.store.set_lca(self.lca_block.header_hash) + await self.block_store.set_lca(self.lca_block.header_hash) # If LCA changed update the unspent store elif old_lca.header_hash != self.lca_block.header_hash: # New LCA is lower height but not the a parent of old LCA (Reorg) fork_h = self._find_fork_point_in_chain(old_lca, self.lca_block) # Rollback to fork - await self.unspent_store.rollback_lca_to_block(fork_h) + await self.coin_store.rollback_lca_to_block(fork_h) # Add blocks between fork point and new lca fork_hash = self.height_to_hash[fork_h] fork_head = self.headers[fork_hash] await self._from_fork_to_lca(fork_head, self.lca_block) - if not self.store.get_sync_mode(): + if not sync_mode: await self.recreate_diff_stores() if not genesis: - await self.store.set_lca(self.lca_block.header_hash) + await self.block_store.set_lca(self.lca_block.header_hash) else: # If LCA has not changed just update the difference - self.unspent_store.nuke_diffs() + self.coin_store.nuke_diffs() # Create DiffStore await self._create_diffs_for_tips(self.lca_block) async def recreate_diff_stores(self): # Nuke DiffStore - self.unspent_store.nuke_diffs() + self.coin_store.nuke_diffs() # Create DiffStore await self._create_diffs_for_tips(self.lca_block) @@ -981,7 +986,7 @@ class Blockchain: while True: if tip_hash == target.header_hash: break - full = await self.store.get_block(tip_hash) + full = await self.block_store.get_block(tip_hash) if full is None: return blocks.append(full) @@ -989,23 +994,23 @@ class Blockchain: if len(blocks) == 0: return blocks.reverse() - await self.unspent_store.new_heads(blocks) + await self.coin_store.new_heads(blocks) async def _from_fork_to_lca(self, fork_point: Header, lca: Header): - """ Selects blocks between fork_point and LCA, and then adds them to unspent_store. """ + """ Selects blocks between fork_point and LCA, and then adds them to coin_store. """ blocks: List[FullBlock] = [] tip_hash: bytes32 = lca.header_hash while True: if tip_hash == fork_point.header_hash: break - full = await self.store.get_block(tip_hash) + full = await self.block_store.get_block(tip_hash) if not full: return blocks.append(full) tip_hash = full.prev_header_hash blocks.reverse() - await self.unspent_store.add_lcas(blocks) + await self.coin_store.add_lcas(blocks) def _validate_merkle_root( self, @@ -1127,7 +1132,7 @@ class Blockchain: additions_since_fork: Dict[bytes32, Tuple[Coin, uint32]] = {} removals_since_fork: Set[bytes32] = set() coinbases_since_fork: Dict[bytes32, uint32] = {} - curr: Optional[FullBlock] = await self.store.get_block(block.prev_header_hash) + curr: Optional[FullBlock] = await self.block_store.get_block(block.prev_header_hash) assert curr is not None log.info(f"curr.height is: {curr.height}, fork height is: {fork_h}") while curr.height > fork_h: @@ -1146,7 +1151,7 @@ class Blockchain: ) coinbases_since_fork[curr.header.data.coinbase.name()] = curr.height coinbases_since_fork[curr.header.data.fees_coin.name()] = curr.height - curr = await self.store.get_block(curr.prev_header_hash) + curr = await self.block_store.get_block(curr.prev_header_hash) assert curr is not None removal_coin_records: Dict[bytes32, CoinRecord] = {} @@ -1160,7 +1165,7 @@ class Blockchain: removal_coin_records[new_unspent.name] = new_unspent else: assert prev_header is not None - unspent = await self.unspent_store.get_coin_record(rem, prev_header) + unspent = await self.coin_store.get_coin_record(rem, prev_header) if unspent is not None and unspent.confirmed_block_index <= fork_h: # Spending something in the current chain, confirmed before fork # (We ignore all coins confirmed after fork) diff --git a/src/full_node/full_node.py b/src/full_node/full_node.py index 8ccfce57..55b7aa7b 100644 --- a/src/full_node/full_node.py +++ b/src/full_node/full_node.py @@ -4,6 +4,7 @@ import logging import time from asyncio import Event from typing import AsyncGenerator, List, Optional, Tuple, Dict +import aiosqlite from chiabip158 import PyBIP158 from chiapos import Verifier @@ -45,29 +46,40 @@ from src.util.api_decorators import api_request from src.util.ints import uint32, uint64, uint128 from src.util.errors import Err, ConsensusError from src.types.mempool_inclusion_status import MempoolInclusionStatus +from src.util.default_root import DEFAULT_ROOT_PATH +from src.util.path import mkdir, path_from_root +from src.full_node.block_store import BlockStore +from src.full_node.full_node_store import FullNodeStore +from src.full_node.sync_store import SyncStore OutboundMessageGenerator = AsyncGenerator[OutboundMessage, None] class FullNode: - def __init__( - self, - store: FullNodeStore, - blockchain: Blockchain, + block_store: BlockStore + full_node_store: FullNodeStore + sync_store: SyncStore + coin_store: CoinStore + mempool_manager: MempoolManager + connection: aiosqlite.Connection + blockchain: Blockchain + config: Dict + server: Optional[ChiaServer] + log: logging.Logger + constants: Dict + _shut_down: bool + + @staticmethod + async def create( config: Dict, - mempool_manager: MempoolManager, - coin_store: CoinStore, name: str = None, override_constants=None, ): + self = FullNode() - self.config: Dict = config - self.store: FullNodeStore = store - self.blockchain: Blockchain = blockchain - self.mempool_manager: MempoolManager = mempool_manager + self.config = config + self.server = None self._shut_down = False # Set to true to close all infinite loops - self.server: Optional[ChiaServer] = None - self.coin_store: CoinStore = coin_store self.constants = consensus_constants.copy() if override_constants: for key, value in override_constants.items(): @@ -77,6 +89,27 @@ class FullNode: else: self.log = logging.getLogger(__name__) + root_path = DEFAULT_ROOT_PATH + db_path = path_from_root(root_path, config["database_path"]) + mkdir(db_path.parent) + + # create the store (db) and full node instance + self.connection = await aiosqlite.connect(db_path) + self.block_store = await BlockStore.create(self.connection) + self.full_node_store = await FullNodeStore.create(self.connection) + self.sync_store = await SyncStore.create(self.connection) + genesis: FullBlock = FullBlock.from_bytes(self.constants["GENESIS_BLOCK"]) + await self.block_store.add_block(genesis) + self.coin_store = await CoinStore.create(self.connection) + + self.log.info("Initializing blockchain from disk") + self.blockchain = await Blockchain.create(self.coin_store, self.block_store) + self.log.info("Blockchain initialized") + + self.mempool_manager = MempoolManager(self.coin_store) + await self.mempool_manager.new_tips(await self.blockchain.get_full_tips()) + + def _set_server(self, server: ChiaServer): self.server = server @@ -88,10 +121,10 @@ class FullNode: estimated proof of time rate, so farmer can calulate which proofs are good. """ requests: List[farmer_protocol.ProofOfSpaceFinalized] = [] - async with self.store.lock: + async with self.blockchain.lock: tips: List[Header] = self.blockchain.get_current_tips() for tip in tips: - full_tip: Optional[FullBlock] = await self.store.get_block( + full_tip: Optional[FullBlock] = await self.block_store.get_block( tip.header_hash ) assert full_tip is not None @@ -109,7 +142,7 @@ class FullNode: challenge_hash, tip.height, tip.weight, difficulty ) ) - full_block: Optional[FullBlock] = await self.store.get_block( + full_block: Optional[FullBlock] = await self.block_store.get_block( tips[0].header_hash ) assert full_block is not None @@ -139,7 +172,7 @@ class FullNode: pos_info_requests: List[timelord_protocol.ProofOfSpaceInfo] = [] tips: List[Header] = self.blockchain.get_current_tips() tips_blocks: List[Optional[FullBlock]] = [ - await self.store.get_block(tip.header_hash) for tip in tips + await self.block_store.get_block(tip.header_hash) for tip in tips ] for tip in tips_blocks: assert tip is not None @@ -152,7 +185,7 @@ class FullNode: tip_hashes = [tip.header_hash for tip in tips] tip_infos = [ tup[0] - for tup in list((self.store.get_unfinished_blocks()).items()) + for tup in list((self.full_node_store.get_unfinished_blocks()).items()) if tup[1].prev_header_hash in tip_hashes ] for chall, iters in tip_infos: @@ -238,8 +271,9 @@ class FullNode: self.introducer_task = asyncio.create_task(introducer_client()) - def _shutdown(self): + async def _shutdown(self): self._shut_down = True + await self.connection.close() async def _sync(self) -> OutboundMessageGenerator: """ @@ -253,7 +287,7 @@ class FullNode: """ self.log.info("Starting to perform sync with peers.") self.log.info("Waiting to receive tips from peers.") - self.store.set_waiting_for_tips(True) + self.sync_store.set_waiting_for_tips(True) # TODO: better way to tell that we have finished receiving tips # TODO: fix DOS issue. Attacker can request syncing to an invalid blockchain await asyncio.sleep(5) @@ -264,11 +298,11 @@ class FullNode: # Based on responses from peers about the current heads, see which head is the heaviest # (similar to longest chain rule). - self.store.set_waiting_for_tips(False) + self.sync_store.set_waiting_for_tips(False) potential_tips: List[ Tuple[bytes32, FullBlock] - ] = self.store.get_potential_tips_tuples() + ] = self.sync_store.get_potential_tips_tuples() self.log.info(f"Have collected {len(potential_tips)} potential tips") for header_hash, potential_tip_block in potential_tips: if potential_tip_block.proof_of_time is None: @@ -291,9 +325,9 @@ class FullNode: ) for height in range(0, tip_block.height + 1): - self.store.set_potential_headers_received(uint32(height), Event()) - self.store.set_potential_blocks_received(uint32(height), Event()) - self.store.set_potential_hashes_received(Event()) + self.sync_store.set_potential_headers_received(uint32(height), Event()) + self.sync_store.set_potential_blocks_received(uint32(height), Event()) + self.sync_store.set_potential_hashes_received(Event()) timeout = 200 sleep_interval = 10 @@ -312,7 +346,7 @@ class FullNode: Delivery.RANDOM, ) try: - phr = self.store.get_potential_hashes_received() + phr = self.sync_store.get_potential_hashes_received() assert phr is not None await asyncio.wait_for( phr.wait(), timeout=sleep_interval, @@ -323,7 +357,7 @@ class FullNode: self.log.warning("Did not receive desired header hashes") # Finding the fork point allows us to only download headers and blocks from the fork point - header_hashes = self.store.get_potential_hashes() + header_hashes = self.sync_store.get_potential_hashes() fork_point_height: uint32 = self.blockchain.find_fork_point_alternate_chain( header_hashes ) @@ -368,7 +402,7 @@ class FullNode: blocks_missing = any( [ not ( - self.store.get_potential_headers_received(uint32(h)) + self.sync_store.get_potential_headers_received(uint32(h)) ).is_set() for h in range(batch_start, batch_end) ] @@ -400,7 +434,7 @@ class FullNode: # Wait for the first batch (the next "max_blocks_to_send" blocks to arrive) awaitables = [ - (self.store.get_potential_headers_received(uint32(height))).wait() + (self.sync_store.get_potential_headers_received(uint32(height))).wait() for height in range(height_checkpoint, end_height) ] future = asyncio.gather(*awaitables, return_exceptions=True) @@ -416,7 +450,7 @@ class FullNode: self.log.info(f"Did not receive desired header blocks") for h in range(fork_point_height + 1, tip_height + 1): - header = self.store.get_potential_header(uint32(h)) + header = self.sync_store.get_potential_header(uint32(h)) assert header is not None headers.append(header) @@ -430,7 +464,7 @@ class FullNode: f"Validated weight of headers. Downloaded {len(headers)} headers, tip height {tip_height}" ) assert tip_height == fork_point_height + len(headers) - self.store.clear_potential_headers() + self.sync_store.clear_potential_headers() headers.clear() # Download blocks in batches, and verify them as they come in. We download a few batches ahead, @@ -469,7 +503,7 @@ class FullNode: blocks_missing = any( [ not ( - self.store.get_potential_blocks_received(uint32(h)) + self.sync_store.get_potential_blocks_received(uint32(h)) ).is_set() for h in range(batch_start, batch_end) ] @@ -500,7 +534,7 @@ class FullNode: # Wait for the first batch (the next "max_blocks_to_send" blocks to arrive) awaitables = [ - (self.store.get_potential_blocks_received(uint32(height))).wait() + (self.sync_store.get_potential_blocks_received(uint32(height))).wait() for height in range(height_checkpoint, end_height) ] future = asyncio.gather(*awaitables, return_exceptions=True) @@ -518,7 +552,7 @@ class FullNode: # Verifies this batch, which we are guaranteed to have (since we broke from the above loop) blocks = [] for height in range(height_checkpoint, end_height): - b: Optional[FullBlock] = await self.store.get_potential_block( + b: Optional[FullBlock] = await self.sync_store.get_potential_block( uint32(height) ) assert b is not None @@ -530,7 +564,7 @@ class FullNode: for height in range(height_checkpoint, end_height): if self._shut_down: return - block: Optional[FullBlock] = await self.store.get_potential_block( + block: Optional[FullBlock] = await self.sync_store.get_potential_block( uint32(height) ) assert block is not None @@ -539,12 +573,12 @@ class FullNode: validated, pos = prevalidate_results[index] index += 1 - async with self.store.lock: + async with self.blockchain.lock: ( result, header_block, error_code, - ) = await self.blockchain.receive_block(block, validated, pos) + ) = await self.blockchain.receive_block(block, validated, pos, sync_mode=True) if ( result == ReceiveBlockResult.INVALID_BLOCK or result == ReceiveBlockResult.DISCONNECTED_BLOCK @@ -553,14 +587,11 @@ class FullNode: raise ConsensusError(error_code, block.header_hash) raise RuntimeError(f"Invalid block {block.header_hash}") - # Always immediately add the block to the database, after updating blockchain state - await self.store.add_block(block) - assert ( max([h.height for h in self.blockchain.get_current_tips()]) >= height ) - self.store.set_proof_of_time_estimate_ips( + self.full_node_store.set_proof_of_time_estimate_ips( self.blockchain.get_next_min_iters(block) // ( self.constants["BLOCK_TIME_TARGET"] @@ -582,11 +613,11 @@ class FullNode: Finalize sync by setting sync mode to False, clearing all sync information, and adding any final blocks that we have finalized recently. """ - potential_fut_blocks = (self.store.get_potential_future_blocks()).copy() - self.store.set_sync_mode(False) + potential_fut_blocks = (self.sync_store.get_potential_future_blocks()).copy() + self.full_node_store.set_sync_mode(False) - async with self.store.lock: - await self.store.clear_sync_info() + async with self.blockchain.lock: + await self.sync_store.clear_sync_info() await self.blockchain.recreate_diff_stores() for block in potential_fut_blocks: @@ -646,7 +677,7 @@ class FullNode: Requests a full transaction if we haven't seen it previously, and if the fees are enough. """ # Ignore if syncing - if self.store.get_sync_mode(): + if self.full_node_store.get_sync_mode(): return # Ignore if already seen if self.mempool_manager.seen(transaction.transaction_id): @@ -668,7 +699,7 @@ class FullNode: ) -> OutboundMessageGenerator: """ Peer has requested a full transaction from us. """ # Ignore if syncing - if self.store.get_sync_mode(): + if self.full_node_store.get_sync_mode(): return spend_bundle = self.mempool_manager.get_spendbundle(request.transaction_id) if spend_bundle is None: @@ -698,9 +729,9 @@ class FullNode: If tx is added to mempool, send tx_id to others. (new_transaction) """ # Ignore if syncing - if self.store.get_sync_mode(): + if self.full_node_store.get_sync_mode(): return - async with self.store.lock: + async with self.blockchain.lock: cost, status, error = await self.mempool_manager.add_spendbundle( tx.transaction ) @@ -739,7 +770,7 @@ class FullNode: ) -> OutboundMessageGenerator: # If we don't have an unfinished block for this PoT, we don't care about it if ( - self.store.get_unfinished_block( + self.full_node_store.get_unfinished_block( ( new_proof_of_time.challenge_hash, new_proof_of_time.number_of_iterations, @@ -750,7 +781,7 @@ class FullNode: return # If we already have the PoT in a finished block, return - blocks: List[FullBlock] = await self.store.get_blocks_at( + blocks: List[FullBlock] = await self.block_store.get_blocks_at( [new_proof_of_time.height] ) for block in blocks: @@ -763,7 +794,7 @@ class FullNode: ): return - self.store.add_proof_of_time_heights( + self.full_node_store.add_proof_of_time_heights( (new_proof_of_time.challenge_hash, new_proof_of_time.number_of_iterations), new_proof_of_time.height, ) @@ -784,7 +815,7 @@ class FullNode: async def request_proof_of_time( self, request_proof_of_time: full_node_protocol.RequestProofOfTime ) -> OutboundMessageGenerator: - blocks: List[FullBlock] = await self.store.get_blocks_at( + blocks: List[FullBlock] = await self.block_store.get_blocks_at( [request_proof_of_time.height] ) for block in blocks: @@ -822,7 +853,7 @@ class FullNode: we can complete it. Otherwise, we just verify and propagate the proof. """ if ( - self.store.get_unfinished_block( + self.full_node_store.get_unfinished_block( ( respond_proof_of_time.proof.challenge_hash, respond_proof_of_time.proof.number_of_iterations, @@ -830,7 +861,7 @@ class FullNode: ) is not None ): - height: Optional[uint32] = self.store.get_proof_of_time_heights( + height: Optional[uint32] = self.full_node_store.get_proof_of_time_heights( ( respond_proof_of_time.proof.challenge_hash, respond_proof_of_time.proof.number_of_iterations, @@ -867,7 +898,7 @@ class FullNode: self, new_compact_proof_of_time: full_node_protocol.NewCompactProofOfTime ) -> OutboundMessageGenerator: # If we already have the compact PoT in a finished block, return - blocks: List[FullBlock] = await self.store.get_blocks_at( + blocks: List[FullBlock] = await self.block_store.get_blocks_at( [new_compact_proof_of_time.height] ) for block in blocks: @@ -894,7 +925,7 @@ class FullNode: request_compact_proof_of_time: full_node_protocol.RequestCompactProofOfTime, ) -> OutboundMessageGenerator: # If we already have the compact PoT in a finished block, return it - blocks: List[FullBlock] = await self.store.get_blocks_at( + blocks: List[FullBlock] = await self.block_store.get_blocks_at( [request_compact_proof_of_time.height] ) for block in blocks: @@ -936,7 +967,7 @@ class FullNode: A proof of time, received by a peer full node. If we have the rest of the block, we can complete it. Otherwise, we just verify and propagate the proof. """ - height: Optional[uint32] = self.store.get_proof_of_time_heights( + height: Optional[uint32] = self.full_node_store.get_proof_of_time_heights( ( respond_compact_proof_of_time.proof.challenge_hash, respond_compact_proof_of_time.proof.number_of_iterations, @@ -945,7 +976,7 @@ class FullNode: if height is None: return - blocks: List[FullBlock] = await self.store.get_blocks_at([height]) + blocks: List[FullBlock] = await self.block_store.get_blocks_at([height]) for block in blocks: assert block.proof_of_time is not None if ( @@ -965,7 +996,7 @@ class FullNode: block.transactions_generator, block.transactions_filter, ) - await self.store.add_block(block_new) + await self.block_store.add_block(block_new) yield OutboundMessage( NodeType.FULL_NODE, Message( @@ -997,14 +1028,14 @@ class FullNode: new_unfinished_block.previous_header_hash ): return - prev_block: Optional[FullBlock] = await self.store.get_block( + prev_block: Optional[FullBlock] = await self.block_store.get_block( new_unfinished_block.previous_header_hash ) if prev_block is not None: challenge = self.blockchain.get_challenge(prev_block) if challenge is not None: if ( - self.store.get_unfinished_block( + self.full_node_store.get_unfinished_block( ( challenge.get_hash(), new_unfinished_block.number_of_iterations, @@ -1029,7 +1060,7 @@ class FullNode: async def request_unfinished_block( self, request_unfinished_block: full_node_protocol.RequestUnfinishedBlock ) -> OutboundMessageGenerator: - for _, block in self.store.get_unfinished_blocks().items(): + for _, block in self.full_node_store.get_unfinished_blocks().items(): if block.header_hash == request_unfinished_block.header_hash: yield OutboundMessage( NodeType.FULL_NODE, @@ -1040,7 +1071,7 @@ class FullNode: Delivery.RESPOND, ) return - fetched: Optional[FullBlock] = await self.store.get_block( + fetched: Optional[FullBlock] = await self.block_store.get_block( request_unfinished_block.header_hash ) if fetched is not None: @@ -1074,18 +1105,18 @@ class FullNode: block = respond_unfinished_block.block # Adds the unfinished block to seen, and check if it's seen before, to prevent # processing it twice - if self.store.seen_unfinished_block(block.header_hash): + if self.full_node_store.seen_unfinished_block(block.header_hash): return if not self.blockchain.is_child_of_head(block): return - prev_full_block: Optional[FullBlock] = await self.store.get_block( + prev_full_block: Optional[FullBlock] = await self.block_store.get_block( block.prev_header_hash ) assert prev_full_block is not None - async with self.store.lock: + async with self.blockchain.lock: ( error_code, iterations_needed, @@ -1100,13 +1131,13 @@ class FullNode: challenge_hash = challenge.get_hash() if ( - self.store.get_unfinished_block((challenge_hash, iterations_needed)) + self.full_node_store.get_unfinished_block((challenge_hash, iterations_needed)) is not None ): return expected_time: uint64 = uint64( - int(iterations_needed / (self.store.get_proof_of_time_estimate_ips())) + int(iterations_needed / (self.full_node_store.get_proof_of_time_estimate_ips())) ) if expected_time > self.constants["PROPAGATION_DELAY_THRESHOLD"]: @@ -1114,13 +1145,13 @@ class FullNode: # If this block is slow, sleep to allow faster blocks to come out first await asyncio.sleep(5) - leader: Tuple[uint32, uint64] = self.store.get_unfinished_block_leader() + leader: Tuple[uint32, uint64] = self.full_node_store.get_unfinished_block_leader() if leader is None or block.height > leader[0]: self.log.info( f"This is the first unfinished block at height {block.height}, so propagate." ) # If this is the first block we see at this height, propagate - self.store.set_unfinished_block_leader((block.height, expected_time)) + self.full_node_store.set_unfinished_block_leader((block.height, expected_time)) elif block.height == leader[0]: if expected_time > leader[1] + self.constants["PROPAGATION_THRESHOLD"]: # If VDF is expected to finish X seconds later than the best, don't propagate @@ -1131,13 +1162,13 @@ class FullNode: elif expected_time < leader[1]: self.log.info(f"New best unfinished block at height {block.height}") # If this will be the first block to finalize, update our leader - self.store.set_unfinished_block_leader((leader[0], expected_time)) + self.full_node_store.set_unfinished_block_leader((leader[0], expected_time)) else: # If we have seen an unfinished block at a greater or equal height, don't propagate self.log.info(f"Unfinished block at old height, so don't propagate") return - self.store.add_unfinished_block((challenge_hash, iterations_needed), block) + self.full_node_store.add_unfinished_block((challenge_hash, iterations_needed), block) timelord_request = timelord_protocol.ProofOfSpaceInfo( challenge_hash, iterations_needed @@ -1183,8 +1214,8 @@ class FullNode: self, all_header_hashes: full_node_protocol.AllHeaderHashes ) -> OutboundMessageGenerator: assert len(all_header_hashes.header_hashes) > 0 - self.store.set_potential_hashes(all_header_hashes.header_hashes) - phr = self.store.get_potential_hashes_received() + self.sync_store.set_potential_hashes(all_header_hashes.header_hashes) + phr = self.sync_store.get_potential_hashes_received() assert phr is not None phr.set() for _ in []: # Yields nothing @@ -1197,7 +1228,7 @@ class FullNode: """ A peer requests a list of header blocks, by height. Used for syncing or light clients. """ - full_block: Optional[FullBlock] = await self.store.get_block( + full_block: Optional[FullBlock] = await self.block_store.get_block( request.header_hash ) if full_block is not None: @@ -1229,8 +1260,8 @@ class FullNode: Receive header blocks from a peer. """ self.log.info(f"Received header block {request.header_block.height}.") - self.store.add_potential_header(request.header_block) - (self.store.get_potential_headers_received(request.header_block.height)).set() + self.sync_store.add_potential_header(request.header_block) + (self.sync_store.get_potential_headers_received(request.header_block.height)).set() for _ in []: # Yields nothing yield _ @@ -1240,7 +1271,7 @@ class FullNode: self, request: full_node_protocol.RejectHeaderBlockRequest ) -> OutboundMessageGenerator: self.log.warning(f"Reject header block request, {request}") - if self.store.get_sync_mode(): + if self.full_node_store.get_sync_mode(): yield OutboundMessage(NodeType.FULL_NODE, Message("", None), Delivery.CLOSE) for _ in []: yield _ @@ -1267,7 +1298,7 @@ class FullNode: # Retrieves the correct tip for the challenge tips: List[Header] = self.blockchain.get_current_tips() tips_blocks: List[Optional[FullBlock]] = [ - await self.store.get_block(tip.header_hash) for tip in tips + await self.block_store.get_block(tip.header_hash) for tip in tips ] target_tip_block: Optional[FullBlock] = None target_tip: Optional[Header] = None @@ -1286,7 +1317,7 @@ class FullNode: assert target_tip is not None # Grab best transactions from Mempool for given tip target - async with self.store.lock: + async with self.blockchain.lock: spend_bundle: Optional[ SpendBundle ] = await self.mempool_manager.create_bundle_for_tip(target_tip) @@ -1403,7 +1434,7 @@ class FullNode: block_header_data_hash: bytes32 = block_header_data.get_hash() # Stores this block so we can submit it to the blockchain after it's signed by harvester - self.store.add_candidate_block( + self.full_node_store.add_candidate_block( proof_of_space_hash, solution_program, encoded_filter, @@ -1430,7 +1461,7 @@ class FullNode: """ candidate: Optional[ Tuple[Optional[Program], Optional[bytes], HeaderData, ProofOfSpace] - ] = self.store.get_candidate_block(header_signature.pos_hash) + ] = self.full_node_store.get_candidate_block(header_signature.pos_hash) if candidate is None: self.log.warning( f"PoS hash {header_signature.pos_hash} not found in database" @@ -1468,7 +1499,7 @@ class FullNode: request.proof.number_of_iterations, ) - unfinished_block_obj: Optional[FullBlock] = self.store.get_unfinished_block( + unfinished_block_obj: Optional[FullBlock] = self.full_node_store.get_unfinished_block( dict_key ) if not unfinished_block_obj: @@ -1485,8 +1516,8 @@ class FullNode: unfinished_block_obj.transactions_filter, ) - if self.store.get_sync_mode(): - self.store.add_potential_future_block(new_full_block) + if self.full_node_store.get_sync_mode(): + self.sync_store.add_potential_future_block(new_full_block) else: async for msg in self.respond_block( full_node_protocol.RespondBlock(new_full_block) @@ -1497,7 +1528,7 @@ class FullNode: async def request_block( self, request_block: full_node_protocol.RequestBlock ) -> OutboundMessageGenerator: - block: Optional[FullBlock] = await self.store.get_block( + block: Optional[FullBlock] = await self.block_store.get_block( request_block.header_hash ) if block is not None: @@ -1522,21 +1553,21 @@ class FullNode: """ Receive a full block from a peer full node (or ourselves). """ - if self.store.get_sync_mode(): + if self.full_node_store.get_sync_mode(): # This is a tip sent to us by another peer - if self.store.get_waiting_for_tips(): + if self.sync_store.get_waiting_for_tips(): # Add the block to our potential tips list - self.store.add_potential_tip(respond_block.block) + self.sync_store.add_potential_tip(respond_block.block) return # This is a block we asked for during sync - await self.store.add_potential_block(respond_block.block) + await self.sync_store.add_potential_block(respond_block.block) if ( - self.store.get_sync_mode() - and respond_block.block.height in self.store.potential_blocks_received + self.full_node_store.get_sync_mode() + and respond_block.block.height in self.sync_store.potential_blocks_received ): # If we are still in sync mode, set it - self.store.get_potential_blocks_received( + self.sync_store.get_potential_blocks_received( respond_block.block.height ).set() return @@ -1552,10 +1583,10 @@ class FullNode: val, pos = prevalidate_block[0] prev_lca = self.blockchain.lca_block - async with self.store.lock: + async with self.blockchain.lock: # Tries to add the block to the blockchain added, replaced, error_code = await self.blockchain.receive_block( - respond_block.block, val, pos + respond_block.block, val, pos, sync_mode=False ) if added == ReceiveBlockResult.ADDED_TO_HEAD: await self.mempool_manager.new_tips( @@ -1581,12 +1612,12 @@ class FullNode: respond_block.block.height > tip_height + self.config["sync_blocks_behind_threshold"] ): - async with self.store.lock: - if self.store.get_sync_mode(): + async with self.blockchain.lock: + if self.full_node_store.get_sync_mode(): return - await self.store.clear_sync_info() - self.store.add_potential_tip(respond_block.block) - self.store.set_sync_mode(True) + await self.sync_store.clear_sync_info() + self.sync_store.add_potential_tip(respond_block.block) + self.full_node_store.set_sync_mode(True) self.log.info( f"We are too far behind this block. Our height is {tip_height} and block is at " f"{respond_block.block.height}" @@ -1615,7 +1646,7 @@ class FullNode: respond_block.block.prev_header_hash, ), ) - self.store.add_disconnected_block(respond_block.block) + self.full_node_store.add_disconnected_block(respond_block.block) yield OutboundMessage(NodeType.FULL_NODE, msg, Delivery.RESPOND) return elif added == ReceiveBlockResult.ADDED_TO_HEAD: @@ -1633,8 +1664,8 @@ class FullNode: / self.constants["MIN_ITERS_PROPORTION"] ) self.log.info(f"Difficulty {difficulty} IPS {next_vdf_ips}") - if next_vdf_ips != self.store.get_proof_of_time_estimate_ips(): - self.store.set_proof_of_time_estimate_ips(next_vdf_ips) + if next_vdf_ips != self.full_node_store.get_proof_of_time_estimate_ips(): + self.full_node_store.set_proof_of_time_estimate_ips(next_vdf_ips) rate_update = farmer_protocol.ProofOfTimeRate(next_vdf_ips) self.log.info(f"Sending proof of time rate {next_vdf_ips}") yield OutboundMessage( @@ -1642,7 +1673,7 @@ class FullNode: Message("proof_of_time_rate", rate_update), Delivery.BROADCAST, ) - self.store.clear_seen_unfinished_blocks() + self.full_node_store.clear_seen_unfinished_blocks() challenge: Optional[Challenge] = self.blockchain.get_challenge( respond_block.block @@ -1714,7 +1745,7 @@ class FullNode: raise RuntimeError(f"Invalid result from receive_block {added}") # This code path is reached if added == ADDED_AS_ORPHAN or ADDED_TO_HEAD - next_block: Optional[FullBlock] = self.store.get_disconnected_block_by_prev( + next_block: Optional[FullBlock] = self.full_node_store.get_disconnected_block_by_prev( respond_block.block.header_hash ) @@ -1728,16 +1759,16 @@ class FullNode: # Removes all temporary data for old blocks lowest_tip = min(tip.height for tip in self.blockchain.get_current_tips()) clear_height = uint32(max(0, lowest_tip - 30)) - self.store.clear_candidate_blocks_below(clear_height) - self.store.clear_unfinished_blocks_below(clear_height) - self.store.clear_disconnected_blocks_below(clear_height) + self.full_node_store.clear_candidate_blocks_below(clear_height) + self.full_node_store.clear_unfinished_blocks_below(clear_height) + self.full_node_store.clear_disconnected_blocks_below(clear_height) @api_request async def reject_block_request( self, reject: full_node_protocol.RejectBlockRequest ) -> OutboundMessageGenerator: self.log.warning(f"Rejected block request {reject}") - if self.store.get_sync_mode(): + if self.full_node_store.get_sync_mode(): yield OutboundMessage(NodeType.FULL_NODE, Message("", None), Delivery.CLOSE) for _ in []: yield _ @@ -1806,12 +1837,12 @@ class FullNode: self, tx: wallet_protocol.SendTransaction ) -> OutboundMessageGenerator: # Ignore if syncing - if self.store.get_sync_mode(): + if self.full_node_store.get_sync_mode(): cost = None status = MempoolInclusionStatus.FAILED error: Optional[Err] = Err.UNKNOWN else: - async with self.store.lock: + async with self.blockchain.lock: cost, status, error = await self.mempool_manager.add_spendbundle( tx.transaction ) @@ -1858,7 +1889,7 @@ class FullNode: async def request_all_proof_hashes( self, request: wallet_protocol.RequestAllProofHashes ) -> OutboundMessageGenerator: - proof_hashes_map = await self.store.get_proof_hashes() + proof_hashes_map = await self.block_store.get_proof_hashes() curr = self.blockchain.lca_block hashes: List[Tuple[bytes32, Optional[uint64], Optional[uint64]]] = [] @@ -1910,7 +1941,7 @@ class FullNode: Delivery.RESPOND, ) return - block: Optional[FullBlock] = await self.store.get_block(header_hash) + block: Optional[FullBlock] = await self.block_store.get_block(header_hash) header_hash_again: Optional[bytes32] = self.blockchain.height_to_hash.get( request.starting_height, None ) @@ -1947,7 +1978,7 @@ class FullNode: async def request_header( self, request: wallet_protocol.RequestHeader ) -> OutboundMessageGenerator: - full_block: Optional[FullBlock] = await self.store.get_block( + full_block: Optional[FullBlock] = await self.block_store.get_block( request.header_hash ) if full_block is not None: @@ -1975,7 +2006,7 @@ class FullNode: async def request_removals( self, request: wallet_protocol.RequestRemovals ) -> OutboundMessageGenerator: - block: Optional[FullBlock] = await self.store.get_block(request.header_hash) + block: Optional[FullBlock] = await self.block_store.get_block(request.header_hash) if ( block is None or block.height != request.height @@ -2045,7 +2076,7 @@ class FullNode: async def request_additions( self, request: wallet_protocol.RequestAdditions ) -> OutboundMessageGenerator: - block: Optional[FullBlock] = await self.store.get_block(request.header_hash) + block: Optional[FullBlock] = await self.block_store.get_block(request.header_hash) if ( block is None or block.height != request.height @@ -2127,7 +2158,7 @@ class FullNode: async def request_generator( self, request: wallet_protocol.RequestGenerator ) -> OutboundMessageGenerator: - full_block: Optional[FullBlock] = await self.store.get_block( + full_block: Optional[FullBlock] = await self.block_store.get_block( request.header_hash ) if full_block is not None: diff --git a/src/full_node/full_node_store..py b/src/full_node/full_node_store..py new file mode 100644 index 00000000..6bd1d426 --- /dev/null +++ b/src/full_node/full_node_store..py @@ -0,0 +1,188 @@ +import asyncio +import logging +import aiosqlite +from typing import Dict, List, Optional, Tuple + +from src.types.program import Program +from src.types.full_block import FullBlock +from src.types.header import HeaderData, Header +from src.types.header_block import HeaderBlock +from src.types.proof_of_space import ProofOfSpace +from src.types.sized_bytes import bytes32 +from src.util.hash import std_hash +from src.util.ints import uint32, uint64 + +log = logging.getLogger(__name__) + + +class FullNodeStore: + db: aiosqlite.Connection + # Whether or not we are syncing + sync_mode: bool + # Current estimate of the speed of the network timelords + proof_of_time_estimate_ips: uint64 + # Proof of time heights + proof_of_time_heights: Dict[Tuple[bytes32, uint64], uint32] + # Our best unfinished block + unfinished_blocks_leader: Tuple[uint32, uint64] + # Blocks which we have created, but don't have proof of space yet, old ones are cleared + candidate_blocks: Dict[ + bytes32, + Tuple[Optional[Program], Optional[bytes], HeaderData, ProofOfSpace, uint32], + ] + # Blocks which are not finalized yet (no proof of time), old ones are cleared + unfinished_blocks: Dict[Tuple[bytes32, uint64], FullBlock] + # Header hashes of unfinished blocks that we have seen recently + seen_unfinished_blocks: set + # Blocks which we have received but our blockchain does not reach, old ones are cleared + disconnected_blocks: Dict[bytes32, FullBlock] + + @classmethod + async def create(cls, connection): + self = cls() + + self.db = connection + + await self.db.commit() + + self.sync_mode = False + self.proof_of_time_estimate_ips = uint64(10000) + self.proof_of_time_heights = {} + self.unfinished_blocks_leader = ( + uint32(0), + uint64((1 << 64) - 1), + ) + self.candidate_blocks = {} + self.unfinished_blocks = {} + self.seen_unfinished_blocks = set() + self.disconnected_blocks = {} + return self + + async def _clear_database(self): + async with self.lock: + await self.db.commit() + + def add_disconnected_block(self, block: FullBlock) -> None: + self.disconnected_blocks[block.header_hash] = block + + def get_disconnected_block_by_prev( + self, prev_header_hash: bytes32 + ) -> Optional[FullBlock]: + for _, block in self.disconnected_blocks.items(): + if block.prev_header_hash == prev_header_hash: + return block + return None + + def get_disconnected_block(self, header_hash: bytes32) -> Optional[FullBlock]: + return self.disconnected_blocks.get(header_hash, None) + + def clear_disconnected_blocks_below(self, height: uint32) -> None: + for key in list(self.disconnected_blocks.keys()): + if self.disconnected_blocks[key].height < height: + del self.disconnected_blocks[key] + + def set_sync_mode(self, sync_mode: bool) -> None: + self.sync_mode = sync_mode + + def get_sync_mode(self) -> bool: + return self.sync_mode + + def add_candidate_block( + self, + pos_hash: bytes32, + transactions_generator: Optional[Program], + transactions_filter: Optional[bytes], + header: HeaderData, + pos: ProofOfSpace, + height: uint32 = uint32(0), + ): + self.candidate_blocks[pos_hash] = ( + transactions_generator, + transactions_filter, + header, + pos, + height, + ) + + def get_candidate_block( + self, pos_hash: bytes32 + ) -> Optional[Tuple[Optional[Program], Optional[bytes], HeaderData, ProofOfSpace]]: + res = self.candidate_blocks.get(pos_hash, None) + if res is None: + return None + return (res[0], res[1], res[2], res[3]) + + def clear_candidate_blocks_below(self, height: uint32) -> None: + del_keys = [] + for key, value in self.candidate_blocks.items(): + if value[4] < height: + del_keys.append(key) + for key in del_keys: + try: + del self.candidate_blocks[key] + except KeyError: + pass + + def add_unfinished_block( + self, key: Tuple[bytes32, uint64], block: FullBlock + ) -> None: + self.unfinished_blocks[key] = block + + def get_unfinished_block(self, key: Tuple[bytes32, uint64]) -> Optional[FullBlock]: + return self.unfinished_blocks.get(key, None) + + def seen_unfinished_block(self, header_hash: bytes32) -> bool: + if header_hash in self.seen_unfinished_blocks: + return True + self.seen_unfinished_blocks.add(header_hash) + return False + + def clear_seen_unfinished_blocks(self) -> None: + self.seen_unfinished_blocks.clear() + + def get_unfinished_blocks(self) -> Dict[Tuple[bytes32, uint64], FullBlock]: + return self.unfinished_blocks.copy() + + def clear_unfinished_blocks_below(self, height: uint32) -> None: + del_keys = [] + for key, unf in self.unfinished_blocks.items(): + if unf.height < height: + del_keys.append(key) + for key in del_keys: + try: + del self.unfinished_blocks[key] + except KeyError: + pass + + def set_unfinished_block_leader(self, key: Tuple[bytes32, uint64]) -> None: + self.unfinished_blocks_leader = key + + def get_unfinished_block_leader(self) -> Tuple[bytes32, uint64]: + return self.unfinished_blocks_leader + + def set_proof_of_time_estimate_ips(self, estimate: uint64): + self.proof_of_time_estimate_ips = estimate + + def get_proof_of_time_estimate_ips(self) -> uint64: + return self.proof_of_time_estimate_ips + + def add_proof_of_time_heights( + self, challenge_iters: Tuple[bytes32, uint64], height: uint32 + ) -> None: + self.proof_of_time_heights[challenge_iters] = height + + def get_proof_of_time_heights( + self, challenge_iters: Tuple[bytes32, uint64] + ) -> Optional[uint32]: + return self.proof_of_time_heights.get(challenge_iters, None) + + def clear_proof_of_time_heights_below(self, height: uint32) -> None: + del_keys: List = [] + for key, value in self.proof_of_time_heights.items(): + if value < height: + del_keys.append(key) + for key in del_keys: + try: + del self.proof_of_time_heights[key] + except KeyError: + pass diff --git a/src/full_node/mempool_manager.py b/src/full_node/mempool_manager.py index 18d827bd..11cecd8a 100644 --- a/src/full_node/mempool_manager.py +++ b/src/full_node/mempool_manager.py @@ -32,7 +32,7 @@ log = logging.getLogger(__name__) class MempoolManager: - def __init__(self, unspent_store: CoinStore, override_constants: Dict = {}): + def __init__(self, coin_store: CoinStore, override_constants: Dict = {}): # Allow passing in custom overrides self.constants: Dict = consensus_constants.copy() for key, value in override_constants.items(): @@ -47,7 +47,7 @@ class MempoolManager: # old_mempools will contain transactions that were removed in the last 10 blocks self.old_mempools: SortedDict[uint32, Dict[bytes32, MempoolItem]] = SortedDict() - self.unspent_store = unspent_store + self.coin_store = coin_store tx_per_sec = self.constants["TX_PER_SEC"] sec_per_block = self.constants["BLOCK_TIME_TARGET"] @@ -181,7 +181,7 @@ class MempoolManager: unknown_unspent_error: bool = False removal_amount = uint64(0) for name in removal_names: - removal_record = await self.unspent_store.get_coin_record( + removal_record = await self.coin_store.get_coin_record( name, pool.header ) if removal_record is None and name not in additions_dict: diff --git a/src/full_node/store.py b/src/full_node/store.py deleted file mode 100644 index 0b9c22de..00000000 --- a/src/full_node/store.py +++ /dev/null @@ -1,404 +0,0 @@ -import asyncio -import logging -import aiosqlite -from typing import Dict, List, Optional, Tuple - -from src.types.program import Program -from src.types.full_block import FullBlock -from src.types.header import HeaderData, Header -from src.types.header_block import HeaderBlock -from src.types.proof_of_space import ProofOfSpace -from src.types.sized_bytes import bytes32 -from src.util.hash import std_hash -from src.util.ints import uint32, uint64 - -log = logging.getLogger(__name__) - - -class FullNodeStore: - db: aiosqlite.Connection - # Whether or not we are syncing - sync_mode: bool - # Whether we are waiting for tips (at the start of sync) or already syncing - waiting_for_tips: bool - # Potential new tips that we have received from others. - potential_tips: Dict[bytes32, FullBlock] - # List of all header hashes up to the tip, download up front - potential_hashes: List[bytes32] - # Header blocks received from other peers during sync - potential_headers: Dict[uint32, HeaderBlock] - # Event to signal when header hashes are received - potential_hashes_received: Optional[asyncio.Event] - # Event to signal when headers are received at each height - potential_headers_received: Dict[uint32, asyncio.Event] - # Event to signal when blocks are received at each height - potential_blocks_received: Dict[uint32, asyncio.Event] - # Blocks that we have finalized during sync, queue them up for adding after sync is done - potential_future_blocks: List[FullBlock] - # Current estimate of the speed of the network timelords - proof_of_time_estimate_ips: uint64 - # Proof of time heights - proof_of_time_heights: Dict[Tuple[bytes32, uint64], uint32] - # Our best unfinished block - unfinished_blocks_leader: Tuple[uint32, uint64] - # Blocks which we have created, but don't have proof of space yet, old ones are cleared - candidate_blocks: Dict[ - bytes32, - Tuple[Optional[Program], Optional[bytes], HeaderData, ProofOfSpace, uint32], - ] - # Blocks which are not finalized yet (no proof of time), old ones are cleared - unfinished_blocks: Dict[Tuple[bytes32, uint64], FullBlock] - # Header hashes of unfinished blocks that we have seen recently - seen_unfinished_blocks: set - # Blocks which we have received but our blockchain does not reach, old ones are cleared - disconnected_blocks: Dict[bytes32, FullBlock] - - # Lock - lock: asyncio.Lock - - @classmethod - async def create(cls, connection): - self = cls() - - # All full blocks which have been added to the blockchain. Header_hash -> block - self.db = connection - await self.db.execute( - "CREATE TABLE IF NOT EXISTS blocks(height bigint, header_hash text PRIMARY KEY, block blob)" - ) - - # Blocks received from other peers during sync, cleared after sync - await self.db.execute( - "CREATE TABLE IF NOT EXISTS potential_blocks(height bigint PRIMARY KEY, block blob)" - ) - - # Headers - await self.db.execute( - "CREATE TABLE IF NOT EXISTS headers(height bigint, header_hash " - "text PRIMARY KEY, proof_hash text, header blob)" - ) - - # LCA - await self.db.execute( - "CREATE TABLE IF NOT EXISTS lca(header_hash text PRIMARY KEY)" - ) - - # Height index so we can look up in order of height for sync purposes - await self.db.execute( - "CREATE INDEX IF NOT EXISTS block_height on blocks(height)" - ) - await self.db.execute( - "CREATE INDEX IF NOT EXISTS header_height on headers(height)" - ) - - await self.db.commit() - - self.sync_mode = False - self.waiting_for_tips = True - self.potential_tips = {} - self.potential_hashes = [] - self.potential_headers = {} - self.potential_hashes_received = None - self.potential_headers_received = {} - self.potential_blocks_received = {} - self.potential_future_blocks = [] - self.proof_of_time_estimate_ips = uint64(10000) - self.proof_of_time_heights = {} - self.unfinished_blocks_leader = ( - uint32(0), - uint64((1 << 64) - 1), - ) - self.candidate_blocks = {} - self.unfinished_blocks = {} - self.seen_unfinished_blocks = set() - self.disconnected_blocks = {} - self.lock = asyncio.Lock() # external - return self - - async def _clear_database(self): - async with self.lock: - await self.db.execute("DELETE FROM blocks") - await self.db.execute("DELETE FROM potential_blocks") - await self.db.execute("DELETE FROM headers") - await self.db.commit() - - async def get_lca(self) -> Optional[bytes32]: - cursor = await self.db.execute("SELECT * from lca") - row = await cursor.fetchone() - await cursor.close() - if row is not None: - return bytes32(bytes.fromhex(row[0])) - return None - - async def set_lca(self, header_hash: bytes32) -> None: - await self.db.execute("DELETE FROM lca") - cursor_1 = await self.db.execute( - "INSERT OR REPLACE INTO lca VALUES(?)", (header_hash.hex(),) - ) - await cursor_1.close() - await self.db.commit() - - async def add_block(self, block: FullBlock) -> None: - assert block.proof_of_time is not None - cursor_1 = await self.db.execute( - "INSERT OR REPLACE INTO blocks VALUES(?, ?, ?)", - (block.height, block.header_hash.hex(), bytes(block)), - ) - await cursor_1.close() - proof_hash = std_hash( - block.proof_of_space.get_hash() + block.proof_of_time.output.get_hash() - ) - cursor_2 = await self.db.execute( - ("INSERT OR REPLACE INTO headers VALUES(?, ?, ?, ?)"), - ( - block.height, - block.header_hash.hex(), - proof_hash.hex(), - bytes(block.header), - ), - ) - await cursor_2.close() - await self.db.commit() - - async def get_block(self, header_hash: bytes32) -> Optional[FullBlock]: - cursor = await self.db.execute( - "SELECT * from blocks WHERE header_hash=?", (header_hash.hex(),) - ) - row = await cursor.fetchone() - await cursor.close() - if row is not None: - return FullBlock.from_bytes(row[2]) - return None - - async def get_blocks_at(self, heights: List[uint32]) -> List[FullBlock]: - if len(heights) == 0: - return [] - - heights_db = tuple(heights) - formatted_str = ( - f'SELECT * from blocks WHERE height in ({"?," * (len(heights_db) - 1)}?)' - ) - cursor = await self.db.execute(formatted_str, heights_db) - rows = await cursor.fetchall() - await cursor.close() - blocks: List[FullBlock] = [] - for row in rows: - blocks.append(FullBlock.from_bytes(row[2])) - return blocks - - async def get_headers(self) -> List[Header]: - cursor = await self.db.execute("SELECT * from headers") - rows = await cursor.fetchall() - await cursor.close() - return [Header.from_bytes(row[3]) for row in rows] - - async def get_proof_hashes(self) -> Dict[bytes32, bytes32]: - cursor = await self.db.execute("SELECT header_hash, proof_hash from headers") - rows = await cursor.fetchall() - await cursor.close() - return {bytes.fromhex(row[0]): bytes.fromhex(row[1]) for row in rows} - - async def add_potential_block(self, block: FullBlock) -> None: - cursor = await self.db.execute( - "INSERT OR REPLACE INTO potential_blocks VALUES(?, ?)", - (block.height, bytes(block)), - ) - await cursor.close() - await self.db.commit() - - async def get_potential_block(self, height: uint32) -> Optional[FullBlock]: - cursor = await self.db.execute( - "SELECT * from potential_blocks WHERE height=?", (height,) - ) - row = await cursor.fetchone() - await cursor.close() - if row is not None: - return FullBlock.from_bytes(row[1]) - return None - - def add_disconnected_block(self, block: FullBlock) -> None: - self.disconnected_blocks[block.header_hash] = block - - def get_disconnected_block_by_prev( - self, prev_header_hash: bytes32 - ) -> Optional[FullBlock]: - for _, block in self.disconnected_blocks.items(): - if block.prev_header_hash == prev_header_hash: - return block - return None - - def get_disconnected_block(self, header_hash: bytes32) -> Optional[FullBlock]: - return self.disconnected_blocks.get(header_hash, None) - - def clear_disconnected_blocks_below(self, height: uint32) -> None: - for key in list(self.disconnected_blocks.keys()): - if self.disconnected_blocks[key].height < height: - del self.disconnected_blocks[key] - - def set_sync_mode(self, sync_mode: bool) -> None: - self.sync_mode = sync_mode - - def get_sync_mode(self) -> bool: - return self.sync_mode - - def set_waiting_for_tips(self, waiting_for_tips: bool) -> None: - self.waiting_for_tips = waiting_for_tips - - def get_waiting_for_tips(self) -> bool: - return self.waiting_for_tips - - async def clear_sync_info(self): - self.potential_tips.clear() - self.potential_headers.clear() - cursor = await self.db.execute("DELETE FROM potential_blocks") - await cursor.close() - self.potential_blocks_received.clear() - self.potential_future_blocks.clear() - self.waiting_for_tips = True - - def get_potential_tips_tuples(self) -> List[Tuple[bytes32, FullBlock]]: - return list(self.potential_tips.items()) - - def add_potential_tip(self, block: FullBlock) -> None: - self.potential_tips[block.header_hash] = block - - def get_potential_tip(self, header_hash: bytes32) -> Optional[FullBlock]: - return self.potential_tips.get(header_hash, None) - - def add_potential_header(self, block: HeaderBlock) -> None: - self.potential_headers[block.height] = block - - def get_potential_header(self, height: uint32) -> Optional[HeaderBlock]: - return self.potential_headers.get(height, None) - - def clear_potential_headers(self) -> None: - self.potential_headers.clear() - - def set_potential_hashes(self, potential_hashes: List[bytes32]) -> None: - self.potential_hashes = potential_hashes - - def get_potential_hashes(self) -> List[bytes32]: - return self.potential_hashes - - def set_potential_hashes_received(self, event: asyncio.Event): - self.potential_hashes_received = event - - def get_potential_hashes_received(self) -> Optional[asyncio.Event]: - return self.potential_hashes_received - - def set_potential_headers_received(self, height: uint32, event: asyncio.Event): - self.potential_headers_received[height] = event - - def get_potential_headers_received(self, height: uint32) -> asyncio.Event: - return self.potential_headers_received[height] - - def set_potential_blocks_received(self, height: uint32, event: asyncio.Event): - self.potential_blocks_received[height] = event - - def get_potential_blocks_received(self, height: uint32) -> asyncio.Event: - return self.potential_blocks_received[height] - - def add_potential_future_block(self, block: FullBlock): - self.potential_future_blocks.append(block) - - def get_potential_future_blocks(self): - return self.potential_future_blocks - - def add_candidate_block( - self, - pos_hash: bytes32, - transactions_generator: Optional[Program], - transactions_filter: Optional[bytes], - header: HeaderData, - pos: ProofOfSpace, - height: uint32 = uint32(0), - ): - self.candidate_blocks[pos_hash] = ( - transactions_generator, - transactions_filter, - header, - pos, - height, - ) - - def get_candidate_block( - self, pos_hash: bytes32 - ) -> Optional[Tuple[Optional[Program], Optional[bytes], HeaderData, ProofOfSpace]]: - res = self.candidate_blocks.get(pos_hash, None) - if res is None: - return None - return (res[0], res[1], res[2], res[3]) - - def clear_candidate_blocks_below(self, height: uint32) -> None: - del_keys = [] - for key, value in self.candidate_blocks.items(): - if value[4] < height: - del_keys.append(key) - for key in del_keys: - try: - del self.candidate_blocks[key] - except KeyError: - pass - - def add_unfinished_block( - self, key: Tuple[bytes32, uint64], block: FullBlock - ) -> None: - self.unfinished_blocks[key] = block - - def get_unfinished_block(self, key: Tuple[bytes32, uint64]) -> Optional[FullBlock]: - return self.unfinished_blocks.get(key, None) - - def seen_unfinished_block(self, header_hash: bytes32) -> bool: - if header_hash in self.seen_unfinished_blocks: - return True - self.seen_unfinished_blocks.add(header_hash) - return False - - def clear_seen_unfinished_blocks(self) -> None: - self.seen_unfinished_blocks.clear() - - def get_unfinished_blocks(self) -> Dict[Tuple[bytes32, uint64], FullBlock]: - return self.unfinished_blocks.copy() - - def clear_unfinished_blocks_below(self, height: uint32) -> None: - del_keys = [] - for key, unf in self.unfinished_blocks.items(): - if unf.height < height: - del_keys.append(key) - for key in del_keys: - try: - del self.unfinished_blocks[key] - except KeyError: - pass - - def set_unfinished_block_leader(self, key: Tuple[bytes32, uint64]) -> None: - self.unfinished_blocks_leader = key - - def get_unfinished_block_leader(self) -> Tuple[bytes32, uint64]: - return self.unfinished_blocks_leader - - def set_proof_of_time_estimate_ips(self, estimate: uint64): - self.proof_of_time_estimate_ips = estimate - - def get_proof_of_time_estimate_ips(self) -> uint64: - return self.proof_of_time_estimate_ips - - def add_proof_of_time_heights( - self, challenge_iters: Tuple[bytes32, uint64], height: uint32 - ) -> None: - self.proof_of_time_heights[challenge_iters] = height - - def get_proof_of_time_heights( - self, challenge_iters: Tuple[bytes32, uint64] - ) -> Optional[uint32]: - return self.proof_of_time_heights.get(challenge_iters, None) - - def clear_proof_of_time_heights_below(self, height: uint32) -> None: - del_keys: List = [] - for key, value in self.proof_of_time_heights.items(): - if value < height: - del_keys.append(key) - for key in del_keys: - try: - del self.proof_of_time_heights[key] - except KeyError: - pass diff --git a/src/full_node/sync_store.py b/src/full_node/sync_store.py new file mode 100644 index 00000000..e91e1532 --- /dev/null +++ b/src/full_node/sync_store.py @@ -0,0 +1,145 @@ +import asyncio +import logging +import aiosqlite +from typing import Dict, List, Optional, Tuple + +from src.types.program import Program +from src.types.full_block import FullBlock +from src.types.header import HeaderData, Header +from src.types.header_block import HeaderBlock +from src.types.proof_of_space import ProofOfSpace +from src.types.sized_bytes import bytes32 +from src.util.hash import std_hash +from src.util.ints import uint32, uint64 + +log = logging.getLogger(__name__) + + +class SyncStore: + db: aiosqlite.Connection + # Whether we are waiting for tips (at the start of sync) or already syncing + waiting_for_tips: bool + # Potential new tips that we have received from others. + potential_tips: Dict[bytes32, FullBlock] + # List of all header hashes up to the tip, download up front + potential_hashes: List[bytes32] + # Header blocks received from other peers during sync + potential_headers: Dict[uint32, HeaderBlock] + # Event to signal when header hashes are received + potential_hashes_received: Optional[asyncio.Event] + # Event to signal when headers are received at each height + potential_headers_received: Dict[uint32, asyncio.Event] + # Event to signal when blocks are received at each height + potential_blocks_received: Dict[uint32, asyncio.Event] + # Blocks that we have finalized during sync, queue them up for adding after sync is done + potential_future_blocks: List[FullBlock] + + @classmethod + async def create(cls, connection): + self = cls() + + # All full blocks which have been added to the blockchain. Header_hash -> block + self.db = connection + # Blocks received from other peers during sync, cleared after sync + await self.db.execute( + "CREATE TABLE IF NOT EXISTS potential_blocks(height bigint PRIMARY KEY, block blob)" + ) + + await self.db.commit() + + self.sync_mode = False + self.waiting_for_tips = True + self.potential_tips = {} + self.potential_hashes = [] + self.potential_headers = {} + self.potential_hashes_received = None + self.potential_headers_received = {} + self.potential_blocks_received = {} + self.potential_future_blocks = [] + return self + + async def _clear_database(self): + async with self.lock: + await self.db.execute("DELETE FROM potential_blocks") + await self.db.commit() + + async def add_potential_block(self, block: FullBlock) -> None: + cursor = await self.db.execute( + "INSERT OR REPLACE INTO potential_blocks VALUES(?, ?)", + (block.height, bytes(block)), + ) + await cursor.close() + await self.db.commit() + + async def get_potential_block(self, height: uint32) -> Optional[FullBlock]: + cursor = await self.db.execute( + "SELECT * from potential_blocks WHERE height=?", (height,) + ) + row = await cursor.fetchone() + await cursor.close() + if row is not None: + return FullBlock.from_bytes(row[1]) + return None + + def set_waiting_for_tips(self, waiting_for_tips: bool) -> None: + self.waiting_for_tips = waiting_for_tips + + def get_waiting_for_tips(self) -> bool: + return self.waiting_for_tips + + async def clear_sync_info(self): + self.potential_tips.clear() + self.potential_headers.clear() + cursor = await self.db.execute("DELETE FROM potential_blocks") + await cursor.close() + self.potential_blocks_received.clear() + self.potential_future_blocks.clear() + self.waiting_for_tips = True + + def get_potential_tips_tuples(self) -> List[Tuple[bytes32, FullBlock]]: + return list(self.potential_tips.items()) + + def add_potential_tip(self, block: FullBlock) -> None: + self.potential_tips[block.header_hash] = block + + def get_potential_tip(self, header_hash: bytes32) -> Optional[FullBlock]: + return self.potential_tips.get(header_hash, None) + + def add_potential_header(self, block: HeaderBlock) -> None: + self.potential_headers[block.height] = block + + def get_potential_header(self, height: uint32) -> Optional[HeaderBlock]: + return self.potential_headers.get(height, None) + + def clear_potential_headers(self) -> None: + self.potential_headers.clear() + + def set_potential_hashes(self, potential_hashes: List[bytes32]) -> None: + self.potential_hashes = potential_hashes + + def get_potential_hashes(self) -> List[bytes32]: + return self.potential_hashes + + def set_potential_hashes_received(self, event: asyncio.Event): + self.potential_hashes_received = event + + def get_potential_hashes_received(self) -> Optional[asyncio.Event]: + return self.potential_hashes_received + + def set_potential_headers_received(self, height: uint32, event: asyncio.Event): + self.potential_headers_received[height] = event + + def get_potential_headers_received(self, height: uint32) -> asyncio.Event: + return self.potential_headers_received[height] + + def set_potential_blocks_received(self, height: uint32, event: asyncio.Event): + self.potential_blocks_received[height] = event + + def get_potential_blocks_received(self, height: uint32) -> asyncio.Event: + return self.potential_blocks_received[height] + + def add_potential_future_block(self, block: FullBlock): + self.potential_future_blocks.append(block) + + def get_potential_future_blocks(self): + return self.potential_future_blocks diff --git a/src/rpc/rpc_server.py b/src/rpc/rpc_server.py index f210889e..9fa4c94d 100644 --- a/src/rpc/rpc_server.py +++ b/src/rpc/rpc_server.py @@ -53,11 +53,11 @@ class RpcApiHandler: """ tips: List[Header] = self.full_node.blockchain.get_current_tips() lca: Header = self.full_node.blockchain.lca_block - sync_mode: bool = self.full_node.store.get_sync_mode() + sync_mode: bool = self.full_node.full_node_store.get_sync_mode() difficulty: uint64 = self.full_node.blockchain.get_next_difficulty( lca.header_hash ) - lca_block = await self.full_node.store.get_block(lca.header_hash) + lca_block = await self.full_node.block_store.get_block(lca.header_hash) if lca_block is None: raise web.HTTPNotFound() min_iters: uint64 = self.full_node.blockchain.get_next_min_iters(lca_block) @@ -90,7 +90,7 @@ class RpcApiHandler: raise web.HTTPBadRequest() header_hash = hexstr_to_bytes(request_data["header_hash"]) - block: Optional[FullBlock] = await self.full_node.store.get_block(header_hash) + block: Optional[FullBlock] = await self.full_node.block_store.get_block(header_hash) if block is None: raise web.HTTPNotFound() return obj_to_response(block) @@ -132,7 +132,7 @@ class RpcApiHandler: raise web.HTTPBadRequest() height = request_data["height"] response_headers: List[Header] = [] - for block in (self.full_node.store.get_unfinished_blocks()).values(): + for block in (self.full_node.full_node_store.get_unfinished_blocks()).values(): if block.height == height: response_headers.append(block.header) @@ -218,7 +218,7 @@ class RpcApiHandler: else: header = None - coin_records = await self.full_node.blockchain.unspent_store.get_coin_records_by_puzzle_hash( + coin_records = await self.full_node.blockchain.coin_store.get_coin_records_by_puzzle_hash( puzzle_hash, header ) @@ -232,8 +232,8 @@ class RpcApiHandler: tip_weights = [tip.weight for tip in tips] i = tip_weights.index(max(tip_weights)) max_tip: Header = tips[i] - if self.full_node.store.get_sync_mode(): - potential_tips = self.full_node.store.get_potential_tips_tuples() + if self.full_node.full_node_store.get_sync_mode(): + potential_tips = self.full_node.sync_store.get_potential_tips_tuples() for _, pot_block in potential_tips: if pot_block.weight > max_tip.weight: max_tip = pot_block.header diff --git a/src/server/start_full_node.py b/src/server/start_full_node.py index ba1083d8..6281848a 100644 --- a/src/server/start_full_node.py +++ b/src/server/start_full_node.py @@ -38,25 +38,7 @@ async def async_main(): log = logging.getLogger(__name__) server_closed = False - db_path = path_from_root(root_path, config["database_path"]) - mkdir(db_path.parent) - - # Create the store (DB) and full node instance - connection = await aiosqlite.connect(db_path) - store = await FullNodeStore.create(connection) - - genesis: FullBlock = FullBlock.from_bytes(constants["GENESIS_BLOCK"]) - await store.add_block(genesis) - unspent_store = await CoinStore.create(connection) - - log.info("Initializing blockchain from disk") - blockchain = await Blockchain.create(unspent_store, store) - log.info("Blockchain initialized") - - mempool_manager = MempoolManager(unspent_store) - await mempool_manager.new_tips(await blockchain.get_full_tips()) - - full_node = FullNode(store, blockchain, config, mempool_manager, unspent_store) + full_node = await FullNode.create(config) if config["enable_upnp"]: log.info(f"Attempting to enable UPnP (open up port {config['port']})") @@ -90,12 +72,12 @@ async def async_main(): _ = await server.start_server(full_node._on_connect) rpc_cleanup = None - def master_close_cb(): + async def master_close_cb(): nonlocal server_closed if not server_closed: # Called by the UI, when node is closed, or when a signal is sent log.info("Closing all connections, and server...") - full_node._shutdown() + await full_node._shutdown() server.close_all() server_closed = True @@ -122,9 +104,6 @@ async def async_main(): await rpc_cleanup() log.info("Closed RPC server.") - await connection.close() - log.info("Closed db connection.") - await asyncio.get_running_loop().shutdown_asyncgens() log.info("Node fully closed.") diff --git a/src/simulator/full_node_simulator.py b/src/simulator/full_node_simulator.py index 45b8a74d..b80aaefa 100644 --- a/src/simulator/full_node_simulator.py +++ b/src/simulator/full_node_simulator.py @@ -27,26 +27,6 @@ bt = BlockTools() class FullNodeSimulator(FullNode): - def __init__( - self, - store: FullNodeStore, - blockchain: Blockchain, - config: Dict, - mempool_manager: MempoolManager, - coin_store: CoinStore, - name: str = None, - override_constants=None, - ): - super().__init__( - store, - blockchain, - config, - mempool_manager, - coin_store, - name, - override_constants, - ) - def _set_server(self, server: ChiaServer): super()._set_server(server) @@ -128,7 +108,7 @@ class FullNodeSimulator(FullNode): if tip_hash == self.blockchain.genesis.header_hash: current_blocks.append(self.blockchain.genesis) break - full = await self.store.get_block(tip_hash) + full = await self.block_store.get_block(tip_hash) if full is None: break current_blocks.append(full) diff --git a/src/simulator/start_simulator.py b/src/simulator/start_simulator.py index c65259d9..dd38cf66 100644 --- a/src/simulator/start_simulator.py +++ b/src/simulator/start_simulator.py @@ -46,20 +46,16 @@ async def main(): genesis: FullBlock = FullBlock.from_bytes(test_constants["GENESIS_BLOCK"]) await store.add_block(genesis) - unspent_store = await CoinStore.create(connection) + coin_store = await CoinStore.create(connection) log.info("Initializing blockchain from disk") - blockchain = await Blockchain.create(unspent_store, store, test_constants) + blockchain = await Blockchain.create(coin_store, store, test_constants) - mempool_manager = MempoolManager(unspent_store, test_constants) + mempool_manager = MempoolManager(coin_store, test_constants) await mempool_manager.new_tips(await blockchain.get_full_tips()) - full_node = FullNodeSimulator( - store, - blockchain, + full_node = await FullNodeSimulator.create( config, - mempool_manager, - unspent_store, override_constants=test_constants, ) @@ -119,7 +115,7 @@ async def main(): await store.close() log.info("Closed store.") - await unspent_store.close() + await coin_store.close() log.info("Closed unspent store.") await asyncio.get_running_loop().shutdown_asyncgens() diff --git a/src/wallet/wallet_puzzle_store.py b/src/wallet/wallet_puzzle_store.py index dd7dc697..f0fde7b7 100644 --- a/src/wallet/wallet_puzzle_store.py +++ b/src/wallet/wallet_puzzle_store.py @@ -58,7 +58,7 @@ class WalletPuzzleStore: ) await self.db_connection.execute( - "CREATE INDEX IF NOT EXISTS wallet_ud on derivation_paths(wallet_id)" + "CREATE INDEX IF NOT EXISTS wallet_id on derivation_paths(wallet_id)" ) await self.db_connection.execute( diff --git a/tests/full_node/test_block_store.py b/tests/full_node/test_block_store.py new file mode 100644 index 00000000..b42b459b --- /dev/null +++ b/tests/full_node/test_block_store.py @@ -0,0 +1,104 @@ +import asyncio +from secrets import token_bytes +from pathlib import Path +from typing import Any, Dict +import sqlite3 +import random + +import aiosqlite +import pytest +from src.full_node.block_store import BlockStore +from src.full_node.coin_store import CoinStore +from src.full_node.blockchain import Blockchain +from src.types.full_block import FullBlock +from src.types.sized_bytes import bytes32 +from src.util.ints import uint32, uint64 +from tests.block_tools import BlockTools + +bt = BlockTools() + +test_constants: Dict[str, Any] = { + "DIFFICULTY_STARTING": 5, + "DISCRIMINANT_SIZE_BITS": 16, + "BLOCK_TIME_TARGET": 10, + "MIN_BLOCK_TIME": 2, + "DIFFICULTY_EPOCH": 12, # The number of blocks per epoch + "DIFFICULTY_DELAY": 3, # EPOCH / WARP_FACTOR +} +test_constants["GENESIS_BLOCK"] = bytes( + bt.create_genesis_block(test_constants, bytes([0] * 32), b"0") +) + + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + + +class TestBlockStore: + @pytest.mark.asyncio + async def test_block_store(self): + assert sqlite3.threadsafety == 1 + blocks = bt.get_consecutive_blocks(test_constants, 9, [], 9, b"0") + blocks_alt = bt.get_consecutive_blocks(test_constants, 3, [], 9, b"1") + db_filename = Path("blockchain_test.db") + db_filename_2 = Path("blockchain_test_2.db") + db_filename_3 = Path("blockchain_test_3.db") + + if db_filename.exists(): + db_filename.unlink() + if db_filename_2.exists(): + db_filename_2.unlink() + if db_filename_3.exists(): + db_filename_3.unlink() + + connection = await aiosqlite.connect(db_filename) + connection_2 = await aiosqlite.connect(db_filename_2) + connection_3 = await aiosqlite.connect(db_filename_3) + + db = await BlockStore.create(connection) + db_2 = await BlockStore.create(connection_2) + try: + await db._clear_database() + + genesis = FullBlock.from_bytes(test_constants["GENESIS_BLOCK"]) + + # Save/get block + for block in blocks: + await db.add_block(block) + assert block == await db.get_block(block.header_hash) + + await db.add_block(blocks_alt[2]) + assert len(await db.get_blocks_at([1, 2])) == 3 + + # Get headers (added alt block also, so +1) + assert len(await db.get_headers()) == len(blocks) + 1 + + # Test LCA + assert (await db.get_lca()) is None + + unspent_store = await CoinStore.create(connection) + b: Blockchain = await Blockchain.create(unspent_store, db, test_constants) + + assert (await db.get_lca()) == blocks[-3].header_hash + assert b.lca_block.header_hash == (await db.get_lca()) + + b_2: Blockchain = await Blockchain.create(unspent_store, db, test_constants) + assert (await db.get_lca()) == blocks[-3].header_hash + assert b_2.lca_block.header_hash == (await db.get_lca()) + + except Exception: + await connection.close() + await connection_2.close() + await connection_3.close() + db_filename.unlink() + db_filename_2.unlink() + raise + + await connection.close() + await connection_2.close() + await connection_3.close() + db_filename.unlink() + db_filename_2.unlink() + db_filename_3.unlink() diff --git a/tests/setup_nodes.py b/tests/setup_nodes.py index a578cc0f..bcc6f0d7 100644 --- a/tests/setup_nodes.py +++ b/tests/setup_nodes.py @@ -63,20 +63,6 @@ async def setup_full_node_simulator(db_name, port, introducer_port=None, dic={}) if db_path.exists(): db_path.unlink() - connection = await aiosqlite.connect(db_path) - store_1 = await FullNodeStore.create(connection) - await store_1._clear_database() - unspent_store_1 = await CoinStore.create(connection) - await unspent_store_1._clear_database() - mempool_1 = MempoolManager(unspent_store_1, test_constants_copy) - - b_1: Blockchain = await Blockchain.create( - unspent_store_1, store_1, test_constants_copy - ) - await mempool_1.new_tips(await b_1.get_full_tips()) - - await store_1.add_block(FullBlock.from_bytes(test_constants_copy["GENESIS_BLOCK"])) - net_config = load_config(root_path, "config.yaml") ping_interval = net_config.get("ping_interval") network_id = net_config.get("network_id") @@ -87,12 +73,8 @@ async def setup_full_node_simulator(db_name, port, introducer_port=None, dic={}) if introducer_port is not None: config["introducer_peer"]["host"] = "127.0.0.1" config["introducer_peer"]["port"] = introducer_port - full_node_1 = FullNodeSimulator( - store_1, - b_1, + full_node_1 = await FullNodeSimulator.create( config, - mempool_1, - unspent_store_1, f"full_node_{port}", test_constants_copy, ) @@ -114,10 +96,9 @@ async def setup_full_node_simulator(db_name, port, introducer_port=None, dic={}) yield (full_node_1, server_1) # TEARDOWN - full_node_1._shutdown() + await full_node_1._shutdown() server_1.close_all() await server_1.await_closed() - await connection.close() db_path.unlink() @@ -127,35 +108,19 @@ async def setup_full_node(db_name, port, introducer_port=None, dic={}): for k in dic.keys(): test_constants_copy[k] = dic[k] - db_path = Path(db_name) - connection = await aiosqlite.connect(db_path) - store_1 = await FullNodeStore.create(connection) - await store_1._clear_database() - unspent_store_1 = await CoinStore.create(connection) - await unspent_store_1._clear_database() - mempool_1 = MempoolManager(unspent_store_1, test_constants_copy) - - b_1: Blockchain = await Blockchain.create( - unspent_store_1, store_1, test_constants_copy - ) - await mempool_1.new_tips(await b_1.get_full_tips()) - - await store_1.add_block(FullBlock.from_bytes(test_constants_copy["GENESIS_BLOCK"])) + Path(db_name).unlink() net_config = load_config(root_path, "config.yaml") ping_interval = net_config.get("ping_interval") network_id = net_config.get("network_id") config = load_config(root_path, "config.yaml", "full_node") + config["database_path"] = db_name if introducer_port is not None: config["introducer_peer"]["host"] = "127.0.0.1" config["introducer_peer"]["port"] = introducer_port - full_node_1 = FullNode( - store_1, - b_1, + full_node_1 = await FullNode.create( config, - mempool_1, - unspent_store_1, f"full_node_{port}", test_constants_copy, )