Almost have sync working
This commit is contained in:
parent
3f0c59336f
commit
284faa46e8
|
@ -4,7 +4,7 @@ ping_interval: 120
|
|||
|
||||
# Controls logging of all servers (harvester, farmer, etc..). Each one can be overriden.
|
||||
logging: &logging
|
||||
log_stdout: False # If True, outputs to stdout instead of a file
|
||||
log_stdout: True # If True, outputs to stdout instead of a file
|
||||
log_filename: "chia.log"
|
||||
|
||||
harvester:
|
||||
|
@ -124,9 +124,14 @@ introducer:
|
|||
|
||||
wallet:
|
||||
host: 127.0.0.1
|
||||
port: 8223
|
||||
port: 8449
|
||||
rpc_port: 9256
|
||||
|
||||
# The minimum height that we care about for our transactions. Set to zero
|
||||
# If we are restoring from private key and don't know the height.
|
||||
starting_height: 0
|
||||
num_sync_batches: 10
|
||||
|
||||
full_node_peer:
|
||||
host: 127.0.0.1
|
||||
port: 8444
|
||||
|
|
|
@ -1821,7 +1821,7 @@ class FullNode:
|
|||
)
|
||||
yield OutboundMessage(
|
||||
NodeType.WALLET,
|
||||
Message("respond_block", response),
|
||||
Message("respond_header", response),
|
||||
Delivery.RESPOND,
|
||||
)
|
||||
return
|
||||
|
|
|
@ -41,7 +41,7 @@ async def main():
|
|||
asyncio.get_running_loop().add_signal_handler(signal.SIGINT, server.close_all)
|
||||
asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, server.close_all)
|
||||
|
||||
_ = await server.start_server(config["host"], wallet._on_connect)
|
||||
_ = await server.start_server(config["host"], None)
|
||||
await asyncio.sleep(1)
|
||||
_ = await server.start_client(full_node_peer, None)
|
||||
|
||||
|
|
|
@ -1,17 +1,28 @@
|
|||
from pathlib import Path
|
||||
import asyncio
|
||||
from typing import Dict, Optional
|
||||
from blspy import ExtendedPrivateKey
|
||||
import time
|
||||
from typing import Dict, Optional, Tuple, List
|
||||
import concurrent
|
||||
import logging
|
||||
from blspy import ExtendedPrivateKey
|
||||
from src.util.merkle_set import (
|
||||
confirm_included_already_hashed,
|
||||
confirm_not_included_already_hashed,
|
||||
)
|
||||
from src.protocols import wallet_protocol
|
||||
from src.consensus.constants import constants as consensus_constants
|
||||
from src.server.server import ChiaServer
|
||||
from src.server.outbound_message import OutboundMessage, NodeType, Message, Delivery
|
||||
from src.util.ints import uint32
|
||||
from src.util.ints import uint32, uint64
|
||||
from src.types.sized_bytes import bytes32
|
||||
from src.util.api_decorators import api_request
|
||||
from src.wallet.wallet import Wallet
|
||||
from src.wallet.wallet_state_manager import WalletStateManager
|
||||
from src.wallet.block_record import BlockRecord
|
||||
from src.types.header_block import HeaderBlock
|
||||
from src.types.full_block import FullBlock
|
||||
from src.types.hashable.coin import Coin, hash_coin_list
|
||||
from src.full_node.blockchain import ReceiveBlockResult
|
||||
|
||||
|
||||
class WalletNode:
|
||||
|
@ -22,12 +33,25 @@ class WalletNode:
|
|||
wallet_state_manager: WalletStateManager
|
||||
log: logging.Logger
|
||||
wallet: Wallet
|
||||
cached_blocks: Dict[bytes32, Tuple[BlockRecord, HeaderBlock]]
|
||||
cached_removals: Dict[bytes32, List[bytes32]]
|
||||
cached_additions: Dict[bytes32, List[Coin]]
|
||||
proof_hashes: List[Tuple[bytes32, Optional[uint64]]]
|
||||
header_hashes: List[bytes32]
|
||||
potential_blocks_received: Dict[uint32, asyncio.Event]
|
||||
potential_header_hashes: Dict[uint32, bytes32]
|
||||
constants: Dict
|
||||
short_sync_threshold: int
|
||||
sync_mode: bool
|
||||
_shut_down: bool
|
||||
|
||||
@staticmethod
|
||||
async def create(
|
||||
config: Dict, key_config: Dict, name: str = None, override_constants: Dict = {}
|
||||
config: Dict,
|
||||
key_config: Dict,
|
||||
name: str = None,
|
||||
db_path=None,
|
||||
override_constants: Dict = {},
|
||||
):
|
||||
self = WalletNode()
|
||||
self.config = config
|
||||
|
@ -43,15 +67,31 @@ class WalletNode:
|
|||
self.log = logging.getLogger(__name__)
|
||||
|
||||
pub_hex = self.private_key.get_public_key().serialize().hex()
|
||||
path = Path(f"wallet_db_{pub_hex}.db")
|
||||
if not db_path:
|
||||
path = Path(f"wallet_db_{pub_hex}.db")
|
||||
else:
|
||||
path = db_path
|
||||
|
||||
self.wallet_state_manager = await WalletStateManager.create(
|
||||
config, path, override_constants=override_constants,
|
||||
config, path, override_constants=self.constants,
|
||||
)
|
||||
self.wallet = await Wallet.create(config, key_config, self.wallet_state_manager)
|
||||
|
||||
self.server = None
|
||||
# Normal operation data
|
||||
self.cached_blocks = {}
|
||||
self.cached_removals = {}
|
||||
self.cached_additions = {}
|
||||
|
||||
# Sync data
|
||||
self.sync_mode = False
|
||||
self._shut_down = False
|
||||
self.proof_hashes = []
|
||||
self.header_hashes = []
|
||||
self.short_sync_threshold = 10
|
||||
self.potential_blocks_received = {}
|
||||
self.potential_header_hashes = {}
|
||||
|
||||
self.server = None
|
||||
|
||||
return self
|
||||
|
||||
|
@ -59,13 +99,262 @@ class WalletNode:
|
|||
self.server = server
|
||||
self.wallet.set_server(server)
|
||||
|
||||
def _shutdown(self):
|
||||
self._shut_down = True
|
||||
|
||||
async def _sync(self):
|
||||
"""
|
||||
Wallet has fallen far behind (or is starting up for the first time), and must be synced
|
||||
up to the tip of the blockchain
|
||||
"""
|
||||
# TODO(mariano): implement
|
||||
pass
|
||||
# 1. Get all header hashes
|
||||
self.header_hashes = []
|
||||
self.proof_hashes = []
|
||||
self.potential_header_hashes = {}
|
||||
genesis = FullBlock.from_bytes(self.constants["GENESIS_BLOCK"])
|
||||
genesis_challenge = genesis.proof_of_space.challenge_hash
|
||||
request_header_hashes = wallet_protocol.RequestAllHeaderHashesAfter(
|
||||
uint32(0), genesis_challenge
|
||||
)
|
||||
yield OutboundMessage(
|
||||
NodeType.FULL_NODE,
|
||||
Message("request_all_header_hashes_after", request_header_hashes),
|
||||
Delivery.RESPOND,
|
||||
)
|
||||
timeout = 100
|
||||
start_wait = time.time()
|
||||
while time.time() - start_wait < timeout:
|
||||
if self._shut_down:
|
||||
return
|
||||
if len(self.header_hashes) > 0:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
if len(self.header_hashes) == 0:
|
||||
raise TimeoutError("Took too long to fetch header hashes.")
|
||||
|
||||
# 2. Find fork point
|
||||
fork_point_height: uint32 = self.wallet_state_manager.find_fork_point(
|
||||
self.header_hashes
|
||||
)
|
||||
fork_point_hash: bytes32 = self.header_hashes[fork_point_height]
|
||||
self.log.info(f"Fork point: {fork_point_hash} at height {fork_point_height}")
|
||||
tip_height = (
|
||||
len(self.header_hashes) - 5
|
||||
if len(self.header_hashes) > 5
|
||||
else len(self.header_hashes)
|
||||
)
|
||||
|
||||
header_validate_start_height: uint32
|
||||
if self.config["starting_height"] == 0:
|
||||
header_validate_start_height = fork_point_height
|
||||
else:
|
||||
# Request all proof hashes
|
||||
request_proof_hashes = wallet_protocol.RequestAllProofHashes()
|
||||
yield OutboundMessage(
|
||||
NodeType.FULL_NODE,
|
||||
Message("request_all_proof_hashes", request_proof_hashes),
|
||||
Delivery.RESPOND,
|
||||
)
|
||||
start_wait = time.time()
|
||||
while time.time() - start_wait < timeout:
|
||||
if self._shut_down:
|
||||
return
|
||||
if len(self.proof_hashes) > 0:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
if len(self.proof_hashes) == 0:
|
||||
raise TimeoutError("Took too long to fetch proof hashes.")
|
||||
|
||||
# TODO(mariano): Validate weight
|
||||
# - Request headers for a random subset
|
||||
# - Verify those proofs
|
||||
|
||||
weight = self.wallet_state_manager.block_records[fork_point_hash].weight
|
||||
header_validate_start_height = max(
|
||||
fork_point_height, self.config["starting_height"] - 1
|
||||
)
|
||||
if fork_point_height == 0:
|
||||
difficulty = self.constants["STARTING_DIFFICULTY"]
|
||||
else:
|
||||
fork_point_parent_hash = self.wallet_state_manager.block_records[
|
||||
fork_point_hash
|
||||
].prev_header_hash
|
||||
fork_point_parent_weight = self.wallet_state_manager.block_records[
|
||||
fork_point_parent_hash
|
||||
]
|
||||
difficulty = uint64(weight - fork_point_parent_weight)
|
||||
for height in range(fork_point_height + 1, header_validate_start_height):
|
||||
_, difficulty_change = self.proof_hashes[height]
|
||||
if difficulty_change is not None:
|
||||
difficulty = difficulty_change
|
||||
weight += difficulty
|
||||
block_record = BlockRecord(
|
||||
self.header_hashes[height],
|
||||
self.header_hashes[height - 1],
|
||||
uint32(height),
|
||||
weight,
|
||||
[],
|
||||
[],
|
||||
)
|
||||
res = await self.wallet_state_manager.receive_block(block_record, None)
|
||||
|
||||
# Download headers in batches, and verify them as they come in. We download a few batches ahead,
|
||||
# in case there are delays. TODO(mariano): optimize sync by pipelining
|
||||
for height in range(0, tip_height + 1):
|
||||
self.potential_blocks_received[uint32(height)] = asyncio.Event()
|
||||
last_request_time = float(0)
|
||||
highest_height_requested = uint32(0)
|
||||
request_made = False
|
||||
sleep_interval = 10
|
||||
|
||||
for height_checkpoint in range(
|
||||
header_validate_start_height + 1, tip_height + 1, 1
|
||||
):
|
||||
end_height = min(height_checkpoint + 1, tip_height + 1)
|
||||
|
||||
total_time_slept = 0
|
||||
while True:
|
||||
if self._shut_down:
|
||||
return
|
||||
if total_time_slept > timeout:
|
||||
raise TimeoutError("Took too long to fetch blocks")
|
||||
|
||||
# Request batches that we don't have yet
|
||||
for batch in range(0, self.config["num_sync_batches"]):
|
||||
batch_start = uint32(height_checkpoint + batch)
|
||||
batch_end = min(batch_start + 1, tip_height + 1)
|
||||
|
||||
if batch_start > tip_height:
|
||||
# We have asked for all blocks
|
||||
break
|
||||
|
||||
blocks_missing = any(
|
||||
[
|
||||
not (self.potential_blocks_received[uint32(h)]).is_set()
|
||||
for h in range(batch_start, batch_end)
|
||||
]
|
||||
)
|
||||
if (
|
||||
time.time() - last_request_time > sleep_interval
|
||||
and blocks_missing
|
||||
) or (batch_end - 1) > highest_height_requested:
|
||||
# If we are missing blocks in this batch, and we haven't made a request in a while,
|
||||
# Make a request for this batch. Also, if we have never requested this batch, make
|
||||
# the request
|
||||
self.log.info(f"Requesting sync header {batch_start}")
|
||||
if batch_end - 1 > highest_height_requested:
|
||||
highest_height_requested = uint32(batch_end - 1)
|
||||
request_made = True
|
||||
request_header = wallet_protocol.RequestHeader(
|
||||
batch_start, self.header_hashes[batch_start],
|
||||
)
|
||||
yield OutboundMessage(
|
||||
NodeType.FULL_NODE,
|
||||
Message("request_header", request_header),
|
||||
Delivery.RANDOM,
|
||||
)
|
||||
if request_made:
|
||||
# Reset the timer for requests, so we don't overload other peers with requests
|
||||
last_request_time = time.time()
|
||||
request_made = False
|
||||
|
||||
awaitables = [
|
||||
self.potential_blocks_received[uint32(height)].wait()
|
||||
for height in range(height_checkpoint, end_height)
|
||||
]
|
||||
future = asyncio.gather(*awaitables, return_exceptions=True)
|
||||
try:
|
||||
await asyncio.wait_for(future, timeout=sleep_interval)
|
||||
break
|
||||
except concurrent.futures.TimeoutError:
|
||||
try:
|
||||
await future
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
total_time_slept += sleep_interval
|
||||
self.log.info("Did not receive desired headers")
|
||||
|
||||
# Verifies this batch, which we are guaranteed to have (since we broke from the above loop)
|
||||
for height in range(height_checkpoint, end_height):
|
||||
hh = self.potential_header_hashes[height]
|
||||
block_record, header_block = self.cached_blocks[hh]
|
||||
|
||||
res = await self.wallet_state_manager.receive_block(
|
||||
block_record, header_block
|
||||
)
|
||||
if (
|
||||
res == ReceiveBlockResult.INVALID_BLOCK
|
||||
or res == ReceiveBlockResult.DISCONNECTED_BLOCK
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"Invalid block header {block_record.header_hash}"
|
||||
)
|
||||
self.log.info(
|
||||
f"Finished sync process up to height {max(self.wallet_state_manager.height_to_hash.keys())}"
|
||||
)
|
||||
|
||||
async def _block_finished(
|
||||
self, block_record: BlockRecord, header_block: HeaderBlock
|
||||
):
|
||||
if self.sync_mode:
|
||||
self.potential_blocks_received[uint32(block_record.height)].set()
|
||||
self.potential_header_hashes[block_record.height] = block_record.header_hash
|
||||
self.cached_blocks[block_record.header_hash] = (block_record, header_block)
|
||||
return
|
||||
# 1. If disconnected and close, get parent header and return
|
||||
lca = self.wallet_state_manager.block_records[self.wallet_state_manager.lca]
|
||||
if block_record.prev_header_hash in self.wallet_state_manager.block_records:
|
||||
# We have completed a block that we can add to chain, so add it.
|
||||
res = await self.wallet_state_manager.receive_block(
|
||||
block_record, header_block
|
||||
)
|
||||
if res == ReceiveBlockResult.DISCONNECTED_BLOCK:
|
||||
self.log.error("Attempted to add disconnected block")
|
||||
return
|
||||
elif res == ReceiveBlockResult.INVALID_BLOCK:
|
||||
self.log.error("Attempted to add invalid block")
|
||||
return
|
||||
elif res == ReceiveBlockResult.ALREADY_HAVE_BLOCK:
|
||||
return
|
||||
else:
|
||||
# If we have the next block available, add it
|
||||
if block_record.header_hash in self.cached_blocks:
|
||||
new_br, new_hb = self.cached_blocks[block_record.header_hash]
|
||||
async for msg in self._block_finished(new_br, new_hb):
|
||||
yield msg
|
||||
if res == ReceiveBlockResult.ADDED_TO_HEAD:
|
||||
# Removes outdated cached blocks if we're not syncing
|
||||
if not self.sync_mode:
|
||||
for header_hash in self.cached_blocks:
|
||||
if (
|
||||
block_record.height
|
||||
- self.cached_blocks[header_hash][0].height
|
||||
> 100
|
||||
):
|
||||
del self.cached_blocks[header_hash]
|
||||
if header_hash in self.cached_additions:
|
||||
del self.cached_additions[header_hash]
|
||||
if header_hash in self.cached_removals:
|
||||
del self.cached_removals[header_hash]
|
||||
else:
|
||||
if block_record.height - lca.height < self.short_sync_threshold:
|
||||
# We have completed a block that is in the near future, so cache it, and fetch parent
|
||||
self.cached_blocks[block_record.prev_header_hash] = (
|
||||
block_record,
|
||||
header_block,
|
||||
)
|
||||
|
||||
header_request = wallet_protocol.RequestHeader(
|
||||
uint32(block_record.height - 1), block_record.prev_header_hash,
|
||||
)
|
||||
yield OutboundMessage(
|
||||
NodeType.FULL_NODE,
|
||||
Message("request_header", header_request),
|
||||
Delivery.RESPOND,
|
||||
)
|
||||
return
|
||||
self.log.warning("Block too far ahead in the future, should never get here")
|
||||
return
|
||||
|
||||
@api_request
|
||||
async def transaction_ack(self, ack: wallet_protocol.TransactionAck):
|
||||
|
@ -79,31 +368,37 @@ class WalletNode:
|
|||
async def respond_all_proof_hashes(
|
||||
self, response: wallet_protocol.RespondAllProofHashes
|
||||
):
|
||||
# TODO(mariano): save proof hashes
|
||||
pass
|
||||
if not self.sync_mode:
|
||||
self.log.warning("Receiving proof hashes while not syncing.")
|
||||
return
|
||||
self.proof_hashes = response.hashes
|
||||
|
||||
@api_request
|
||||
async def respond_all_header_hashes_after(
|
||||
self, response: wallet_protocol.RespondAllHeaderHashesAfter
|
||||
):
|
||||
# TODO(mariano): save header_hashes
|
||||
pass
|
||||
if not self.sync_mode:
|
||||
self.log.warning("Receiving header hashes while not syncing.")
|
||||
return
|
||||
self.header_hashes = response.hashes
|
||||
|
||||
@api_request
|
||||
async def reject_all_header_hashes_after_request(
|
||||
self, response: wallet_protocol.RejectAllHeaderHashesAfterRequest
|
||||
):
|
||||
# TODO(mariano): retry
|
||||
self.log.error("All header hashes after request rejected")
|
||||
pass
|
||||
|
||||
@api_request
|
||||
async def new_lca(self, request: wallet_protocol.NewLCA):
|
||||
if self.sync_mode:
|
||||
return
|
||||
# If already seen LCA, ignore.
|
||||
if request.lca_hash in self.wallet_state_manager.block_records:
|
||||
return
|
||||
|
||||
lca = self.wallet_state_manager.block_records[self.wallet_state_manager.lca]
|
||||
|
||||
# If it's not the heaviest chain, ignore.
|
||||
if request.weight < lca.weight:
|
||||
return
|
||||
|
@ -111,12 +406,14 @@ class WalletNode:
|
|||
if int(request.height) - int(lca.height) > self.short_sync_threshold:
|
||||
try:
|
||||
# Performs sync, and catch exceptions so we don't close the connection
|
||||
self.sync_mode = True
|
||||
async for ret_msg in self._sync():
|
||||
yield ret_msg
|
||||
except asyncio.CancelledError:
|
||||
self.log.error("Syncing failed, CancelledError")
|
||||
except BaseException as e:
|
||||
self.log.error(f"Error {type(e)}{e} with syncing")
|
||||
self.sync_mode = False
|
||||
else:
|
||||
header_request = wallet_protocol.RequestHeader(
|
||||
uint32(request.height - 1), request.prev_header_hash
|
||||
|
@ -130,27 +427,23 @@ class WalletNode:
|
|||
@api_request
|
||||
async def respond_header(self, response: wallet_protocol.RespondHeader):
|
||||
block = response.header_block
|
||||
# 0. If we already have, return
|
||||
# If we already have, return
|
||||
if block.header_hash in self.wallet_state_manager.block_records:
|
||||
return
|
||||
|
||||
lca = self.wallet_state_manager.block_records[self.wallet_state_manager.lca]
|
||||
block_record = BlockRecord(
|
||||
block.header_hash,
|
||||
block.prev_header_hash,
|
||||
block.height,
|
||||
block.weight,
|
||||
[],
|
||||
[],
|
||||
)
|
||||
|
||||
# 1. If disconnected and close, get parent header and return
|
||||
if block.prev_header_hash not in self.wallet_state_manager.block_records:
|
||||
if block.height - lca.height < self.short_sync_threshold:
|
||||
header_request = wallet_protocol.RequestHeader(
|
||||
uint32(block.height - 1), block.prev_header_hash,
|
||||
)
|
||||
yield OutboundMessage(
|
||||
NodeType.FULL_NODE,
|
||||
Message("request_header", header_request),
|
||||
Delivery.RESPOND,
|
||||
)
|
||||
return
|
||||
|
||||
# 2. If we have transactions, fetch adds/deletes
|
||||
# If we have transactions, fetch adds/deletes
|
||||
if response.transactions_filter is not None:
|
||||
# Caches the block so we can finalize it when additions and removals arrive
|
||||
self.cached_blocks[block.header_hash] = (block_record, block)
|
||||
(
|
||||
additions,
|
||||
removals,
|
||||
|
@ -178,90 +471,153 @@ class WalletNode:
|
|||
Delivery.RESPOND,
|
||||
)
|
||||
else:
|
||||
block_record = BlockRecord(block.header_hash, block.prev_header_hash, block.height, block.weight, [], [])
|
||||
res = await self.wallet_state_manager.receive_block(block_record, block)
|
||||
# 3. If we don't have, don't fetch
|
||||
# 4. If we have the next header cached, process it
|
||||
pass
|
||||
# If we don't have any transactions in filter, don't fetch, and finish the block
|
||||
async for msg in self._block_finished(block_record, block):
|
||||
yield msg
|
||||
|
||||
@api_request
|
||||
async def reject_header_request(
|
||||
self, response: wallet_protocol.RejectHeaderRequest
|
||||
):
|
||||
# TODO(mariano): implement
|
||||
pass
|
||||
self.log.error("Header request rejected")
|
||||
|
||||
@api_request
|
||||
async def respond_removals(self, response: wallet_protocol.RespondRemovals):
|
||||
# TODO(mariano): implement
|
||||
pass
|
||||
if response.header_hash not in self.cached_blocks:
|
||||
self.log.warning("Do not have header for removals")
|
||||
return
|
||||
block_record, header_block = self.cached_blocks[response.header_hash]
|
||||
assert response.height == block_record.height
|
||||
|
||||
removals: List[bytes32]
|
||||
if response.proofs is None:
|
||||
# Find our removals
|
||||
all_coins: List[Coin] = []
|
||||
for coin_name, coin in response.coins:
|
||||
if coin is not None:
|
||||
all_coins.append(coin)
|
||||
removals = [
|
||||
c.name()
|
||||
for c in await self.wallet_state_manager.get_relevant_removals(
|
||||
all_coins
|
||||
)
|
||||
]
|
||||
else:
|
||||
removals = []
|
||||
assert len(response.coins) == len(response.proofs)
|
||||
for i in range(len(response.coins)):
|
||||
# Coins are in the same order as proofs
|
||||
assert response.coins[i][0] == response.proofs[i][0]
|
||||
coin = response.coins[i][1]
|
||||
if coin is None:
|
||||
assert confirm_not_included_already_hashed(
|
||||
header_block.header.data.removals_root,
|
||||
response.coins[i][0],
|
||||
response.proofs[i][1],
|
||||
)
|
||||
else:
|
||||
assert response.coins[i][0] == coin.name
|
||||
assert confirm_included_already_hashed(
|
||||
header_block.header.data.removals_root,
|
||||
coin.name(),
|
||||
response.proofs[i][1],
|
||||
)
|
||||
removals.append(response.coins[i][0])
|
||||
additions = self.cached_additions.get(response.header_hash, [])
|
||||
new_br = BlockRecord(
|
||||
block_record.header_hash,
|
||||
block_record.prev_header_hash,
|
||||
block_record.height,
|
||||
block_record.weight,
|
||||
additions,
|
||||
removals,
|
||||
)
|
||||
self.cached_blocks[response.header_hash] = (new_br, header_block)
|
||||
self.cached_removals[response.header_hash] = removals
|
||||
|
||||
if response.header_hash in self.cached_additions:
|
||||
# We have collected all three things: header, additions, and removals. Can proceed.
|
||||
# Otherwise, we wait for the additions to arrive
|
||||
async for msg in self._block_finished(new_br, header_block):
|
||||
yield msg
|
||||
|
||||
@api_request
|
||||
async def reject_removals_request(
|
||||
self, response: wallet_protocol.RejectRemovalsRequest
|
||||
):
|
||||
# TODO(mariano): implement
|
||||
pass
|
||||
self.log.error("Removals request rejected")
|
||||
|
||||
# @api_request
|
||||
# async def received_body(self, response: wallet_protocol.RespondBody):
|
||||
# """
|
||||
# Called when body is received from the FullNode
|
||||
# """
|
||||
@api_request
|
||||
async def respond_additions(self, response: wallet_protocol.RespondAdditions):
|
||||
if response.header_hash not in self.cached_blocks:
|
||||
self.log.warning("Do not have header for additions")
|
||||
return
|
||||
block_record, header_block = self.cached_blocks[response.header_hash]
|
||||
assert response.height == block_record.height
|
||||
|
||||
# # Retry sending queued up transactions
|
||||
# await self.retry_send_queue()
|
||||
additions: List[Coin]
|
||||
if response.proofs is None:
|
||||
# Find our removals
|
||||
all_coins: List[Coin] = []
|
||||
for puzzle_hash, coin_list_0 in response.coins:
|
||||
all_coins += coin_list_0
|
||||
additions = await self.wallet_state_manager.get_relevant_additions(
|
||||
all_coins
|
||||
)
|
||||
else:
|
||||
additions = []
|
||||
assert len(response.coins) == len(response.proofs)
|
||||
for i in range(len(response.coins)):
|
||||
assert response.coins[i][0] == response.proofs[i][0]
|
||||
coin_list_1: List[Coin] = response.coins[i][1]
|
||||
puzzle_hash_proof: bytes32 = response.proofs[i][1]
|
||||
coin_list_proof: Optional[bytes32] = response.proofs[i][2]
|
||||
if len(coin_list_1) == 0:
|
||||
# Verify exclusion proof for puzzle hash
|
||||
assert confirm_not_included_already_hashed(
|
||||
header_block.header.data.additions_root,
|
||||
response.coins[i][0],
|
||||
puzzle_hash_proof,
|
||||
)
|
||||
else:
|
||||
# Verify inclusion proof for puzzle hash
|
||||
assert confirm_included_already_hashed(
|
||||
header_block.header.data.additions_root,
|
||||
response.coins[i][0],
|
||||
puzzle_hash_proof,
|
||||
)
|
||||
# Verify inclusion proof for coin list
|
||||
assert confirm_included_already_hashed(
|
||||
header_block.header.data.additions_root,
|
||||
hash_coin_list(coin_list_1),
|
||||
coin_list_proof,
|
||||
)
|
||||
for coin in coin_list_1:
|
||||
assert coin.puzzle_hash == response.coins[i][0]
|
||||
additions += coin_list_1
|
||||
removals = self.cached_removals.get(response.header_hash, [])
|
||||
new_br = BlockRecord(
|
||||
block_record.header_hash,
|
||||
block_record.prev_header_hash,
|
||||
block_record.height,
|
||||
block_record.weight,
|
||||
additions,
|
||||
removals,
|
||||
)
|
||||
self.cached_blocks[response.header_hash] = (new_br, header_block)
|
||||
self.cached_additions[response.header_hash] = additions
|
||||
|
||||
# additions: List[Coin] = []
|
||||
if response.header_hash in self.cached_removals:
|
||||
# We have collected all three things: header, additions, and removals. Can proceed.
|
||||
# Otherwise, we wait for the removals to arrive
|
||||
async for msg in self._block_finished(new_br, header_block):
|
||||
yield msg
|
||||
|
||||
# if await self.wallet.can_generate_puzzle_hash(
|
||||
# response.header.data.coinbase.puzzle_hash
|
||||
# ):
|
||||
# await self.wallet_state_manager.coin_added(
|
||||
# response.header.data.coinbase, response.height, True
|
||||
# )
|
||||
# if await self.wallet.can_generate_puzzle_hash(
|
||||
# response.header.data.fees_coin.puzzle_hash
|
||||
# ):
|
||||
# await self.wallet_state_manager.coin_added(
|
||||
# response.header.data.fees_coin, response.height, True
|
||||
# )
|
||||
|
||||
# npc_list: List[NPC]
|
||||
# if response.transactions_generator:
|
||||
# error, npc_list, cost = get_name_puzzle_conditions(
|
||||
# response.transactions_generator
|
||||
# )
|
||||
|
||||
# additions.extend(additions_for_npc(npc_list))
|
||||
|
||||
# for added_coin in additions:
|
||||
# if await self.wallet.can_generate_puzzle_hash(added_coin.puzzle_hash):
|
||||
# await self.wallet_state_manager.coin_added(
|
||||
# added_coin, response.height, False
|
||||
# )
|
||||
|
||||
# for npc in npc_list:
|
||||
# if await self.wallet.can_generate_puzzle_hash(npc.puzzle_hash):
|
||||
# await self.wallet_state_manager.coin_removed(
|
||||
# npc.coin_name, response.height
|
||||
# )
|
||||
|
||||
# async def retry_send_queue(self):
|
||||
# records = await self.wallet_state_manager.get_send_queue()
|
||||
# for record in records:
|
||||
# if record.spend_bundle:
|
||||
# await self._send_transaction(record.spend_bundle)
|
||||
|
||||
# async def _send_transaction(self, spend_bundle: SpendBundle):
|
||||
# """ Sends spendbundle to connected full Nodes."""
|
||||
# await self.wallet_state_manager.add_pending_transaction(spend_bundle)
|
||||
|
||||
# msg = OutboundMessage(
|
||||
# NodeType.FULL_NODE,
|
||||
# Message("wallet_transaction", spend_bundle),
|
||||
# Delivery.BROADCAST,
|
||||
# )
|
||||
# if self.server:
|
||||
# async for reply in self.server.push_message(msg):
|
||||
# self.log.info(reply)
|
||||
@api_request
|
||||
async def reject_additions_request(
|
||||
self, response: wallet_protocol.RejectAdditionsRequest
|
||||
):
|
||||
# TODO(mariano): implement
|
||||
self.log.error("Additions request rejected")
|
||||
|
|
|
@ -3,7 +3,7 @@ from pathlib import Path
|
|||
|
||||
from typing import Dict, Optional, List, Set, Tuple
|
||||
import logging
|
||||
|
||||
import asyncio
|
||||
from chiabip158 import PyBIP158
|
||||
|
||||
from src.types.hashable.coin import Coin
|
||||
|
@ -39,6 +39,9 @@ class WalletStateManager:
|
|||
lca: Optional[bytes32]
|
||||
start_index: int
|
||||
|
||||
# Makes sure only one asyncio thread is changing the blockchain state at one time
|
||||
lock: asyncio.Lock
|
||||
|
||||
log: logging.Logger
|
||||
|
||||
# TODO Don't allow user to send tx until wallet is synced
|
||||
|
@ -63,12 +66,14 @@ class WalletStateManager:
|
|||
self.log = logging.getLogger(name)
|
||||
else:
|
||||
self.log = logging.getLogger(__name__)
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
self.wallet_store = await WalletStore.create(db_path)
|
||||
self.tx_store = await WalletTransactionStore.create(db_path)
|
||||
self.puzzle_store = await WalletPuzzleStore.create(db_path)
|
||||
|
||||
self.lca = None
|
||||
self.synced = False
|
||||
self.height_to_hash = {}
|
||||
self.block_records = await self.wallet_store.get_lca_path()
|
||||
genesis = FullBlock.from_bytes(self.constants["GENESIS_BLOCK"])
|
||||
|
||||
|
@ -281,62 +286,96 @@ class WalletStateManager:
|
|||
records = await self.tx_store.get_all_transactions()
|
||||
return records
|
||||
|
||||
def find_fork_point(self, alternate_chain: List[bytes32]) -> uint32:
|
||||
"""
|
||||
Takes in an alternate blockchain (headers), and compares it to self. Returns the last header
|
||||
where both blockchains are equal.
|
||||
"""
|
||||
lca: BlockRecord = self.block_records[self.lca]
|
||||
|
||||
if lca.height >= len(alternate_chain) - 1:
|
||||
raise ValueError("Alternate chain is shorter")
|
||||
low: uint32 = uint32(0)
|
||||
high = lca.height
|
||||
while low + 1 < high:
|
||||
mid = uint32((low + high) // 2)
|
||||
if self.height_to_hash[uint32(mid)] != alternate_chain[mid]:
|
||||
high = mid
|
||||
else:
|
||||
low = mid
|
||||
if low == high and low == 0:
|
||||
assert self.height_to_hash[uint32(0)] == alternate_chain[0]
|
||||
return uint32(0)
|
||||
assert low + 1 == high
|
||||
if self.height_to_hash[uint32(low)] == alternate_chain[low]:
|
||||
if self.height_to_hash[uint32(high)] == alternate_chain[high]:
|
||||
return high
|
||||
else:
|
||||
return low
|
||||
elif low > 0:
|
||||
assert self.height_to_hash[uint32(low - 1)] == alternate_chain[low - 1]
|
||||
return uint32(low - 1)
|
||||
else:
|
||||
raise ValueError("Invalid genesis block")
|
||||
|
||||
async def receive_block(
|
||||
self, block: BlockRecord, header_block: Optional[HeaderBlock] = None,
|
||||
) -> ReceiveBlockResult:
|
||||
if block.header_hash in self.block_records:
|
||||
return ReceiveBlockResult.ALREADY_HAVE_BLOCK
|
||||
async with self.lock:
|
||||
if block.header_hash in self.block_records:
|
||||
return ReceiveBlockResult.ALREADY_HAVE_BLOCK
|
||||
|
||||
if block.prev_header_hash not in self.block_records or block.height == 0:
|
||||
return ReceiveBlockResult.DISCONNECTED_BLOCK
|
||||
if block.prev_header_hash not in self.block_records and block.height != 0:
|
||||
return ReceiveBlockResult.DISCONNECTED_BLOCK
|
||||
|
||||
if header_block is not None:
|
||||
# TODO: validate header block
|
||||
pass
|
||||
if header_block is not None:
|
||||
# TODO: validate header block
|
||||
pass
|
||||
|
||||
self.block_records[block.header_hash] = block
|
||||
await self.wallet_store.add_block_record(block, False)
|
||||
self.block_records[block.header_hash] = block
|
||||
await self.wallet_store.add_block_record(block, False)
|
||||
|
||||
# Genesis case
|
||||
if self.lca is None:
|
||||
assert block.height == 0
|
||||
await self.wallet_store.add_block_to_path(block.header_hash)
|
||||
self.lca = block.header_hash
|
||||
for coin in block.additions:
|
||||
await self.coin_added(coin, block.height, False)
|
||||
for coin_name in block.removals:
|
||||
await self.coin_removed(coin_name, block.height)
|
||||
self.height_to_hash[uint32(0)] = block.header_hash
|
||||
return ReceiveBlockResult.ADDED_TO_HEAD
|
||||
# Genesis case
|
||||
if self.lca is None:
|
||||
assert block.height == 0
|
||||
await self.wallet_store.add_block_to_path(block.header_hash)
|
||||
self.lca = block.header_hash
|
||||
for coin in block.additions:
|
||||
await self.coin_added(coin, block.height, False)
|
||||
for coin_name in block.removals:
|
||||
await self.coin_removed(coin_name, block.height)
|
||||
self.height_to_hash[uint32(0)] = block.header_hash
|
||||
return ReceiveBlockResult.ADDED_TO_HEAD
|
||||
|
||||
# Not genesis, updated LCA
|
||||
if block.weight > self.block_records[self.lca].weight:
|
||||
# Not genesis, updated LCA
|
||||
if block.weight > self.block_records[self.lca].weight:
|
||||
|
||||
fork_h = self.find_fork_for_lca(block)
|
||||
await self.reorg_rollback(fork_h)
|
||||
fork_h = self.find_fork_for_lca(block)
|
||||
await self.reorg_rollback(fork_h)
|
||||
|
||||
# Add blocks between fork point and new lca
|
||||
fork_hash = self.height_to_hash[fork_h]
|
||||
blocks_to_add: List[BlockRecord] = []
|
||||
tip_hash: bytes32 = block.header_hash
|
||||
while True:
|
||||
if tip_hash == fork_hash:
|
||||
break
|
||||
record = self.block_records[tip_hash]
|
||||
blocks_to_add.append(record)
|
||||
tip_hash = record.prev_header_hash
|
||||
blocks_to_add.reverse()
|
||||
# Add blocks between fork point and new lca
|
||||
fork_hash = self.height_to_hash[fork_h]
|
||||
blocks_to_add: List[BlockRecord] = []
|
||||
tip_hash: bytes32 = block.header_hash
|
||||
while True:
|
||||
if tip_hash == fork_hash:
|
||||
break
|
||||
record = self.block_records[tip_hash]
|
||||
blocks_to_add.append(record)
|
||||
tip_hash = record.prev_header_hash
|
||||
blocks_to_add.reverse()
|
||||
|
||||
for path_block in blocks_to_add:
|
||||
self.height_to_hash[path_block.height] = path_block.header_hash
|
||||
await self.wallet_store.add_block_to_path(path_block.header_hash)
|
||||
for coin in path_block.additions:
|
||||
await self.coin_added(coin, path_block.height, False)
|
||||
for coin_name in path_block.removals:
|
||||
await self.coin_removed(coin_name, path_block.height)
|
||||
return ReceiveBlockResult.ADDED_TO_HEAD
|
||||
for path_block in blocks_to_add:
|
||||
self.height_to_hash[path_block.height] = path_block.header_hash
|
||||
await self.wallet_store.add_block_to_path(path_block.header_hash)
|
||||
for coin in path_block.additions:
|
||||
await self.coin_added(coin, path_block.height, False)
|
||||
for coin_name in path_block.removals:
|
||||
await self.coin_removed(coin_name, path_block.height)
|
||||
self.lca = block.header_hash
|
||||
return ReceiveBlockResult.ADDED_TO_HEAD
|
||||
|
||||
return ReceiveBlockResult.ADDED_AS_ORPHAN
|
||||
return ReceiveBlockResult.ADDED_AS_ORPHAN
|
||||
|
||||
def find_fork_for_lca(self, new_lca: BlockRecord) -> uint32:
|
||||
""" Tries to find height where new chain (current) diverged from the old chain where old_lca was the LCA"""
|
||||
|
@ -358,8 +397,7 @@ class WalletStateManager:
|
|||
self, transactions_fitler: bytes
|
||||
) -> Tuple[List[bytes32], List[bytes32]]:
|
||||
""" Returns a list of our coin ids, and a list of puzzle_hashes that positively match with provided filter. """
|
||||
|
||||
tx_filter = PyBIP158(transactions_fitler)
|
||||
tx_filter = PyBIP158([b for b in transactions_fitler])
|
||||
my_coin_records: Set[
|
||||
CoinRecord
|
||||
] = await self.wallet_store.get_coin_records_by_spent(False)
|
||||
|
|
|
@ -16,7 +16,6 @@ class WalletStore:
|
|||
|
||||
db_connection: aiosqlite.Connection
|
||||
# Whether or not we are syncing
|
||||
sync_mode: bool = False
|
||||
lock: asyncio.Lock
|
||||
coin_record_cache: Dict[str, CoinRecord]
|
||||
cache_size: uint32
|
||||
|
@ -221,7 +220,6 @@ class WalletStore:
|
|||
if br.height > max_height:
|
||||
max_height = br.height
|
||||
# Makes sure there's exactly one block per height
|
||||
print(max_height, len(rows))
|
||||
assert max_height == len(rows) - 1
|
||||
return hash_to_br
|
||||
|
||||
|
@ -232,7 +230,7 @@ class WalletStore:
|
|||
block_record.header_hash.hex(),
|
||||
block_record.height,
|
||||
in_lca_path,
|
||||
block_record,
|
||||
bytes(block_record),
|
||||
),
|
||||
)
|
||||
await cursor.close()
|
||||
|
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
from src.types.peer_info import PeerInfo
|
||||
from src.protocols import full_node_protocol
|
||||
from src.util.ints import uint16
|
||||
from tests.setup_nodes import setup_two_nodes, test_constants, bt
|
||||
from tests.setup_nodes import setup_two_nodes, setup_node_and_wallet, test_constants, bt
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
@ -21,6 +21,11 @@ class TestFullSync:
|
|||
async for _ in setup_two_nodes():
|
||||
yield _
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def wallet_node(self):
|
||||
async for _ in setup_node_and_wallet():
|
||||
yield _
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_sync(self, two_nodes):
|
||||
num_blocks = 100
|
||||
|
@ -56,6 +61,43 @@ class TestFullSync:
|
|||
f"Took too long to process blocks, stopped at: {time.time() - start_unf}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_sync_wallet(self, wallet_node):
|
||||
num_blocks = 25
|
||||
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10)
|
||||
full_node_1, wallet_node, server_1, server_2 = wallet_node
|
||||
|
||||
for i in range(1, num_blocks):
|
||||
async for _ in full_node_1.respond_block(
|
||||
full_node_protocol.RespondBlock(blocks[i])
|
||||
):
|
||||
pass
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
|
||||
await asyncio.sleep(2) # Allow connections to get made
|
||||
start_unf = time.time()
|
||||
|
||||
while time.time() - start_unf < 60:
|
||||
# The second node should eventually catch up to the first one, and have the
|
||||
# same tip at height num_blocks - 1.
|
||||
if (
|
||||
wallet_node.wallet_state_manager.block_records[
|
||||
wallet_node.wallet_state_manager.lca
|
||||
].height
|
||||
== num_blocks - 7
|
||||
):
|
||||
print(f"Time taken to sync {num_blocks} is {time.time() - start_unf}")
|
||||
|
||||
return
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
raise Exception(
|
||||
f"Took too long to process blocks, stopped at: {time.time() - start_unf}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_short_sync(self, two_nodes):
|
||||
num_blocks = 10
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
from typing import Any, Dict
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
import blspy
|
||||
from secrets import token_bytes
|
||||
|
||||
from src.full_node.blockchain import Blockchain
|
||||
from src.full_node.mempool_manager import MempoolManager
|
||||
|
@ -8,6 +10,7 @@ from src.full_node.store import FullNodeStore
|
|||
from src.full_node.full_node import FullNode
|
||||
from src.server.connection import NodeType
|
||||
from src.server.server import ChiaServer
|
||||
from src.wallet.wallet_node import WalletNode
|
||||
from src.types.full_block import FullBlock
|
||||
from src.full_node.coin_store import CoinStore
|
||||
from tests.block_tools import BlockTools
|
||||
|
@ -91,6 +94,33 @@ async def setup_full_node(db_name, port, introducer_port=None, dic={}):
|
|||
Path(db_name).unlink()
|
||||
|
||||
|
||||
async def setup_wallet_node(port, introducer_port=None, dic={}):
|
||||
config = load_config("config.yaml", "wallet")
|
||||
key_config = {
|
||||
"wallet_sk": bytes(blspy.ExtendedPrivateKey.from_seed(b"1234")).hex(),
|
||||
}
|
||||
test_constants_copy = test_constants.copy()
|
||||
for k in dic.keys():
|
||||
test_constants_copy[k] = dic[k]
|
||||
db_path = "test-wallet-db" + token_bytes(32).hex()
|
||||
if Path(db_path).exists():
|
||||
Path(db_path).unlink()
|
||||
|
||||
wallet = await WalletNode.create(
|
||||
config, key_config, db_path=db_path, override_constants=test_constants_copy
|
||||
)
|
||||
server = ChiaServer(port, wallet, NodeType.WALLET)
|
||||
wallet.set_server(server)
|
||||
|
||||
yield (wallet, server)
|
||||
|
||||
server.close_all()
|
||||
await wallet.wallet_state_manager.clear_all_stores()
|
||||
await wallet.wallet_state_manager.close_all_stores()
|
||||
Path(db_path).unlink()
|
||||
await server.await_closed()
|
||||
|
||||
|
||||
async def setup_harvester(port, dic={}):
|
||||
config = load_config("config.yaml", "harvester")
|
||||
|
||||
|
@ -195,6 +225,24 @@ async def setup_two_nodes(dic={}):
|
|||
pass
|
||||
|
||||
|
||||
async def setup_node_and_wallet(dic={}):
|
||||
node_iters = [
|
||||
setup_full_node("blockchain_test.db", 21234, dic=dic),
|
||||
setup_wallet_node(21235, dic=dic),
|
||||
]
|
||||
|
||||
full_node, s1 = await node_iters[0].__anext__()
|
||||
wallet, s2 = await node_iters[1].__anext__()
|
||||
|
||||
yield (full_node, wallet, s1, s2)
|
||||
|
||||
for node_iter in node_iters:
|
||||
try:
|
||||
await node_iter.__anext__()
|
||||
except StopAsyncIteration:
|
||||
pass
|
||||
|
||||
|
||||
async def setup_full_system(dic={}):
|
||||
node_iters = [
|
||||
setup_introducer(21233),
|
||||
|
|
Loading…
Reference in New Issue