pyblack autoformatting

This commit is contained in:
Alex Wice 2019-11-18 13:50:31 +09:00
parent 8be7181c9d
commit 4f9ba50b15
39 changed files with 1896 additions and 805 deletions
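The hunks below are what black produces at its default 88-column line length: double quotes, one argument or list item per line, trailing commas. A minimal sketch of reproducing one of these rewrites with black's Python API (format_str and FileMode are black's public entry points; this snippet is illustrative, not part of the commit):

import black

src = 'dev_dependencies = ["pytest", "flake8", "mypy", "isort", "autoflake", "black", "pytest-asyncio"]\n'

# FileMode() defaults to an 88-column line length, which forces this list
# onto one item per line with a trailing comma, exactly as in setup.py below.
print(black.format_str(src, mode=black.FileMode()))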

View File

@@ -2,17 +2,25 @@
from setuptools import setup
dependencies = ["blspy", "cbor2", "pyyaml", "asyncssh"]
dev_dependencies = ["pytest", "flake8", "mypy", "isort", "autoflake", "black", "pytest-asyncio"]
dev_dependencies = [
"pytest",
"flake8",
"mypy",
"isort",
"autoflake",
"black",
"pytest-asyncio",
]
setup(
name='chiablockchain',
version='0.1.2',
author='Mariano Sorgente',
author_email='mariano@chia.net',
description='Chia proof of space plotting, proving, and verifying (wraps C++)',
license='Apache License',
python_requires='>=3.7, <4',
keywords='chia blockchain node',
name="chiablockchain",
version="0.1.2",
author="Mariano Sorgente",
author_email="mariano@chia.net",
description="Chia proof of space plotting, proving, and verifying (wraps C++)",
license="Apache License",
python_requires=">=3.7, <4",
keywords="chia blockchain node",
install_requires=dependencies + dev_dependencies,
long_description=open("README.md").read(),
zip_safe=False,

View File

@@ -7,8 +7,10 @@ import blspy
from src.consensus.block_rewards import calculate_block_reward
from src.consensus.constants import constants as consensus_constants
from src.consensus.pot_iterations import (calculate_ips_from_iterations,
calculate_iterations_quality)
from src.consensus.pot_iterations import (
calculate_ips_from_iterations,
calculate_iterations_quality,
)
from src.db.database import FullNodeStore
from src.types.full_block import FullBlock
from src.types.header_block import HeaderBlock
@@ -25,11 +27,14 @@ class ReceiveBlockResult(Enum):
showing whether the block was added to the chain (extending a head or not),
and if not, why it was not added.
"""
ADDED_TO_HEAD = 1 # Added to one of the heads, this block is now a new head
ADDED_AS_ORPHAN = 2 # Added as an orphan/stale block (block that is not a head or ancestor of a head)
INVALID_BLOCK = 3 # Block was not added because it was invalid
ALREADY_HAVE_BLOCK = 4 # Block is already present in this blockchain
DISCONNECTED_BLOCK = 5 # Block's parent (previous pointer) is not in this blockchain
DISCONNECTED_BLOCK = (
5 # Block's parent (previous pointer) is not in this blockchain
)
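A hedged sketch of how callers branch on this enum: the sync path later in this commit raises a RuntimeError when receive_block returns either of the two failure results. Self-contained, with the enum restated for runnability:

from enum import Enum

class ReceiveBlockResult(Enum):
    ADDED_TO_HEAD = 1
    ADDED_AS_ORPHAN = 2
    INVALID_BLOCK = 3
    ALREADY_HAVE_BLOCK = 4
    DISCONNECTED_BLOCK = 5

# The full node's sync loop treats these two results as fatal for a block:
result = ReceiveBlockResult.DISCONNECTED_BLOCK
if result in (ReceiveBlockResult.INVALID_BLOCK, ReceiveBlockResult.DISCONNECTED_BLOCK):
    print("reject block; the sync code raises RuntimeError here")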
class Blockchain:
@@ -50,7 +55,7 @@ class Blockchain:
self.heads = [block]
# TODO: are cases where the blockchain "fans out" handled appropriately?
self.height_to_hash[block.height] = block.header_hash
if not self.heads:
self.genesis = FullBlock.from_bytes(self.constants["GENESIS_BLOCK"])
@@ -72,7 +77,7 @@ class Blockchain:
True iff the block is the direct ancestor of a head.
"""
for head in self.tips:
if (block.prev_header_hash == head.header_hash):
if block.prev_header_hash == head.header_hash:
return True
return False
@@ -86,24 +91,34 @@
else:
return None
async def get_header_blocks_by_height(self, heights: List[uint64], tip_header_hash: bytes32) -> List[HeaderBlock]:
async def get_header_blocks_by_height(
self, heights: List[uint64], tip_header_hash: bytes32
) -> List[HeaderBlock]:
"""
Returns a list of header blocks, one for each height requested.
"""
# TODO: optimize, don't look at all blocks
sorted_heights = sorted([(height, index) for index, height in enumerate(heights)], reverse=True)
sorted_heights = sorted(
[(height, index) for index, height in enumerate(heights)], reverse=True
)
curr_full_block: Optional[FullBlock] = await self.store.get_block(tip_header_hash)
curr_full_block: Optional[FullBlock] = await self.store.get_block(
tip_header_hash
)
if not curr_full_block:
raise BlockNotInBlockchain(f"Header hash {tip_header_hash} not present in chain.")
raise BlockNotInBlockchain(
f"Header hash {tip_header_hash} not present in chain."
)
curr_block = curr_full_block.header_block
headers: List[Tuple[int, HeaderBlock]] = []
for height, index in sorted_heights:
if height > curr_block.height:
raise ValueError("Height is not valid for tip {tip_header_hash}")
while height < curr_block.height:
curr_block = (await self.store.get_block(curr_block.header.data.prev_header_hash)).header_block
curr_block = (
await self.store.get_block(curr_block.header.data.prev_header_hash)
).header_block
headers.append((index, curr_block))
return [b for index, b in sorted(headers)]
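The sort at the top of this method is worth isolating: each requested height is paired with its request index, the pairs are visited in descending height order so a single backward walk of the chain serves every request, and results are restored to request order at the end. A self-contained sketch of that bookkeeping (the strings stand in for HeaderBlock objects):

heights = [0, 5, 3]
sorted_heights = sorted(
    [(height, index) for index, height in enumerate(heights)], reverse=True
)  # [(5, 1), (3, 2), (0, 0)] -- one backward walk can satisfy all three
headers = [(index, f"header@{height}") for height, index in sorted_heights]
print([b for index, b in sorted(headers)])  # ['header@0', 'header@5', 'header@3']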
@@ -118,21 +133,32 @@ class Blockchain:
high = lca.height
while low + 1 < high:
mid = (low + high) // 2
if self.height_to_hash[uint64(mid)] != alternate_chain[mid].header.get_hash():
if (
self.height_to_hash[uint64(mid)]
!= alternate_chain[mid].header.get_hash()
):
high = mid
else:
low = mid
if low == high and low == 0:
assert self.height_to_hash[uint64(0)] == alternate_chain[0].header.get_hash()
assert (
self.height_to_hash[uint64(0)] == alternate_chain[0].header.get_hash()
)
return alternate_chain[0]
assert low + 1 == high
if self.height_to_hash[uint64(low)] == alternate_chain[low].header.get_hash():
if self.height_to_hash[uint64(high)] == alternate_chain[high].header.get_hash():
if (
self.height_to_hash[uint64(high)]
== alternate_chain[high].header.get_hash()
):
return alternate_chain[high]
else:
return alternate_chain[low]
elif low > 0:
assert self.height_to_hash[uint64(low - 1)] == alternate_chain[low - 1].header.get_hash()
assert (
self.height_to_hash[uint64(low - 1)]
== alternate_chain[low - 1].header.get_hash()
)
return alternate_chain[low - 1]
else:
raise ValueError("Invalid genesis block")
@@ -154,18 +180,29 @@ class Blockchain:
# Epochs are defined as intervals of DIFFICULTY_EPOCH blocks, inclusive and indexed at 0.
# For example, [0-2047], [2048-4095], etc. The difficulty changes DIFFICULTY_DELAY into the
# epoch, as opposed to the first block (as in Bitcoin).
elif next_height % self.constants["DIFFICULTY_EPOCH"] != self.constants["DIFFICULTY_DELAY"]:
elif (
next_height % self.constants["DIFFICULTY_EPOCH"]
!= self.constants["DIFFICULTY_DELAY"]
):
# Not at a point where difficulty would change
prev_block = await self.store.get_block(block.prev_header_hash)
if prev_block is None:
raise Exception("Previous block is invalid.")
return uint64(block.header_block.challenge.total_weight - prev_block.header_block.challenge.total_weight)
return uint64(
block.header_block.challenge.total_weight
- prev_block.header_block.challenge.total_weight
)
# old diff curr diff new diff
# ----------|-----|----------------------|-----|-----...
# h1 h2 h3 i-1
# Height1 is the last block 2 epochs ago, so we can include the time to mine the 1st block in the previous epoch
height1 = uint64(next_height - self.constants["DIFFICULTY_EPOCH"] - self.constants["DIFFICULTY_DELAY"] - 1)
height1 = uint64(
next_height
- self.constants["DIFFICULTY_EPOCH"]
- self.constants["DIFFICULTY_DELAY"]
- 1
)
# Height2 is the DIFFICULTY DELAYth block in the previous epoch
height2 = uint64(next_height - self.constants["DIFFICULTY_EPOCH"] - 1)
# Height3 is the last block in the previous epoch
@@ -175,11 +212,17 @@ class Blockchain:
# current difficulty
block1, block2, block3 = None, None, None
if block.header_block not in self.get_current_tips() or height3 not in self.height_to_hash:
if (
block.header_block not in self.get_current_tips()
or height3 not in self.height_to_hash
):
# This means we are either on a fork, or on one of the chains, but after the LCA,
# so we manually backtrack.
curr = block
while (curr.height not in self.height_to_hash or self.height_to_hash[curr.height] != curr.header_hash):
while (
curr.height not in self.height_to_hash
or self.height_to_hash[curr.height] != curr.header_hash
):
if curr.height == height1:
block1 = curr
elif curr.height == height2:
@@ -210,29 +253,50 @@ class Blockchain:
# took constants["BLOCK_TIME_TARGET"] seconds to mine.
genesis = await self.store.get_block(self.height_to_hash[uint64(0)])
assert genesis is not None
timestamp1 = (genesis.header_block.header.data.timestamp - self.constants["BLOCK_TIME_TARGET"])
timestamp1 = (
genesis.header_block.header.data.timestamp
- self.constants["BLOCK_TIME_TARGET"]
)
timestamp2 = block2.header_block.header.data.timestamp # i - 2048 + 512 - 1
timestamp3 = block3.header_block.header.data.timestamp # i - 512 - 1
# Numerator fits in 128 bits, so big int is not necessary
# We multiply by the denominators here, so we only have one fraction in the end (avoiding floating point)
term1 = (self.constants["DIFFICULTY_DELAY"] * Tp * (timestamp3 - timestamp2) *
self.constants["BLOCK_TIME_TARGET"])
term2 = ((self.constants["DIFFICULTY_WARP_FACTOR"] - 1) * (self.constants["DIFFICULTY_EPOCH"] -
self.constants["DIFFICULTY_DELAY"]) * Tc
* (timestamp2 - timestamp1) * self.constants["BLOCK_TIME_TARGET"])
term1 = (
self.constants["DIFFICULTY_DELAY"]
* Tp
* (timestamp3 - timestamp2)
* self.constants["BLOCK_TIME_TARGET"]
)
term2 = (
(self.constants["DIFFICULTY_WARP_FACTOR"] - 1)
* (self.constants["DIFFICULTY_EPOCH"] - self.constants["DIFFICULTY_DELAY"])
* Tc
* (timestamp2 - timestamp1)
* self.constants["BLOCK_TIME_TARGET"]
)
# Round down after the division
new_difficulty: uint64 = uint64((term1 + term2) //
(self.constants["DIFFICULTY_WARP_FACTOR"] *
(timestamp3 - timestamp2) *
(timestamp2 - timestamp1)))
new_difficulty: uint64 = uint64(
(term1 + term2)
// (
self.constants["DIFFICULTY_WARP_FACTOR"]
* (timestamp3 - timestamp2)
* (timestamp2 - timestamp1)
)
)
# Only change by a max factor, to prevent attacks, as in greenpaper, and must be at least 1
if new_difficulty >= Tc:
return min(new_difficulty, uint64(self.constants["DIFFICULTY_FACTOR"] * Tc))
else:
return max([uint64(1), new_difficulty, uint64(Tc // self.constants["DIFFICULTY_FACTOR"])])
return max(
[
uint64(1),
new_difficulty,
uint64(Tc // self.constants["DIFFICULTY_FACTOR"]),
]
)
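The term1/term2 expression above is a warp-weighted average of the block rates in two timestamp windows, kept entirely in integers to avoid floating point. A hedged numeric check (the constant values below are assumptions for illustration; only the names come from the code): if every block in both windows arrived exactly BLOCK_TIME_TARGET seconds apart, the formula should return the current difficulty unchanged.

DIFFICULTY_EPOCH, DIFFICULTY_DELAY = 2048, 512     # assumed values
DIFFICULTY_WARP_FACTOR, BLOCK_TIME_TARGET = 4, 10  # assumed values
Tp = Tc = 1000  # previous and current difficulty, assumed equal here

# On-target spacing: height2 - height1 = 512 blocks, height3 - height1 = 2048 blocks.
timestamp1, timestamp2, timestamp3 = 0, 512 * 10, 2048 * 10

term1 = DIFFICULTY_DELAY * Tp * (timestamp3 - timestamp2) * BLOCK_TIME_TARGET
term2 = (
    (DIFFICULTY_WARP_FACTOR - 1)
    * (DIFFICULTY_EPOCH - DIFFICULTY_DELAY)
    * Tc
    * (timestamp2 - timestamp1)
    * BLOCK_TIME_TARGET
)
new_difficulty = (term1 + term2) // (
    DIFFICULTY_WARP_FACTOR * (timestamp3 - timestamp2) * (timestamp2 - timestamp1)
)
print(new_difficulty)  # 1000: difficulty is unchanged when blocks arrive on target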
async def get_next_ips(self, header_hash) -> uint64:
"""
@@ -253,11 +317,18 @@ class Blockchain:
raise Exception("Previous block is invalid.")
proof_of_space = block.header_block.proof_of_space
difficulty = await self.get_next_difficulty(prev_block.header_hash)
iterations = block.header_block.challenge.total_iters - prev_block.header_block.challenge.total_iters
prev_ips = calculate_ips_from_iterations(proof_of_space, difficulty, iterations,
self.constants["MIN_BLOCK_TIME"])
iterations = (
block.header_block.challenge.total_iters
- prev_block.header_block.challenge.total_iters
)
prev_ips = calculate_ips_from_iterations(
proof_of_space, difficulty, iterations, self.constants["MIN_BLOCK_TIME"]
)
if next_height % self.constants["DIFFICULTY_EPOCH"] != self.constants["DIFFICULTY_DELAY"]:
if (
next_height % self.constants["DIFFICULTY_EPOCH"]
!= self.constants["DIFFICULTY_DELAY"]
):
# Not at a point where ips would change, so return the previous ips
# TODO: cache this for efficiency
return prev_ips
@@ -268,16 +339,27 @@ class Blockchain:
# block of the last epoch. Basically, it's the total iterations over time of the previous epoch.
# Height1 is the last block 2 epochs ago, so we can include the iterations taken to mine the first block in the epoch
height1 = uint64(next_height - self.constants["DIFFICULTY_EPOCH"] - self.constants["DIFFICULTY_DELAY"] - 1)
height1 = uint64(
next_height
- self.constants["DIFFICULTY_EPOCH"]
- self.constants["DIFFICULTY_DELAY"]
- 1
)
# Height2 is the last block in the previous epoch
height2 = uint64(next_height - self.constants["DIFFICULTY_DELAY"] - 1)
block1, block2 = None, None
if block.header_block not in self.get_current_tips() or height2 not in self.height_to_hash:
if (
block.header_block not in self.get_current_tips()
or height2 not in self.height_to_hash
):
# This means we are either on a fork, or on one of the chains, but after the LCA,
# so we manually backtrack.
curr = block
while (curr.height not in self.height_to_hash or self.height_to_hash[curr.height] != curr.header_hash):
while (
curr.height not in self.height_to_hash
or self.height_to_hash[curr.height] != curr.header_hash
):
if curr.height == height1:
block1 = curr
elif curr.height == height2:
@@ -300,7 +382,10 @@ class Blockchain:
# took constants["BLOCK_TIME_TARGET"] seconds to mine.
genesis = await self.store.get_block(self.height_to_hash[uint64(0)])
assert genesis is not None
timestamp1 = genesis.header_block.header.data.timestamp - self.constants["BLOCK_TIME_TARGET"]
timestamp1 = (
genesis.header_block.header.data.timestamp
- self.constants["BLOCK_TIME_TARGET"]
)
iters1 = genesis.header_block.challenge.total_iters
timestamp2 = block2.header_block.header.data.timestamp
@@ -312,7 +397,9 @@ class Blockchain:
if new_ips >= prev_ips:
return min(new_ips, uint64(self.constants["IPS_FACTOR"] * new_ips))
else:
return max([uint64(1), new_ips, uint64(prev_ips // self.constants["IPS_FACTOR"])])
return max(
[uint64(1), new_ips, uint64(prev_ips // self.constants["IPS_FACTOR"])]
)
async def receive_block(self, block: FullBlock) -> ReceiveBlockResult:
"""
@@ -337,7 +424,9 @@ class Blockchain:
else:
return ReceiveBlockResult.ADDED_AS_ORPHAN
async def validate_unfinished_block(self, block: FullBlock, genesis: bool = False) -> bool:
async def validate_unfinished_block(
self, block: FullBlock, genesis: bool = False
) -> bool:
"""
Block validation algorithm. Returns true if the candidate block is fully valid
(except for proof of time). The same as validate_block, but without proof of time
@@ -362,12 +451,18 @@ class Blockchain:
if not fetched:
break
curr = fetched
if len(last_timestamps) != self.constants["NUMBER_OF_TIMESTAMPS"] and curr.body.coinbase.height != 0:
if (
len(last_timestamps) != self.constants["NUMBER_OF_TIMESTAMPS"]
and curr.body.coinbase.height != 0
):
return False
prev_time: uint64 = uint64(int(sum(last_timestamps) / len(last_timestamps)))
if block.header_block.header.data.timestamp < prev_time:
return False
if block.header_block.header.data.timestamp > time.time() + self.constants["MAX_FUTURE_TIME"]:
if (
block.header_block.header.data.timestamp
> time.time() + self.constants["MAX_FUTURE_TIME"]
):
return False
# 3. Check filter hash is correct TODO
@@ -397,8 +492,9 @@ class Blockchain:
# 8. Check harvester signature of header data is valid based on harvester key
if not block.header_block.header.harvester_signature.verify(
[blspy.Util.hash256(block.header_block.header.data.get_hash())],
[block.header_block.proof_of_space.plot_pubkey]):
[blspy.Util.hash256(block.header_block.header.data.get_hash())],
[block.header_block.proof_of_space.plot_pubkey],
):
return False
# 9. Check proof of space based on challenge
@@ -416,12 +512,17 @@ class Blockchain:
return False
# 11. Check coinbase amount
if calculate_block_reward(block.body.coinbase.height) != block.body.coinbase.amount:
if (
calculate_block_reward(block.body.coinbase.height)
!= block.body.coinbase.amount
):
return False
# 12. Check coinbase signature with pool pk
if not block.body.coinbase_signature.verify([blspy.Util.hash256(bytes(block.body.coinbase))],
[block.header_block.proof_of_space.pool_pubkey]):
if not block.body.coinbase_signature.verify(
[blspy.Util.hash256(bytes(block.body.coinbase))],
[block.header_block.proof_of_space.pool_pubkey],
):
return False
# TODO: 13a. check transactions
@@ -454,7 +555,10 @@ class Blockchain:
# 2. Check proof of space hash
if not block.header_block.challenge or not block.header_block.proof_of_time:
return False
if block.header_block.proof_of_space.get_hash() != block.header_block.challenge.proof_of_space_hash:
if (
block.header_block.proof_of_space.get_hash()
!= block.header_block.challenge.proof_of_space_hash
):
return False
# 3. Check number of iterations on PoT is correct, based on prev block and PoS
@@ -463,14 +567,21 @@ class Blockchain:
if pos_quality is None:
return False
number_of_iters: uint64 = calculate_iterations_quality(pos_quality, block.header_block.proof_of_space.size,
difficulty, ips, self.constants["MIN_BLOCK_TIME"])
number_of_iters: uint64 = calculate_iterations_quality(
pos_quality,
block.header_block.proof_of_space.size,
difficulty,
ips,
self.constants["MIN_BLOCK_TIME"],
)
if number_of_iters != block.header_block.proof_of_time.number_of_iterations:
return False
# 4. Check PoT
if not block.header_block.proof_of_time.is_valid(self.constants["DISCRIMINANT_SIZE_BITS"]):
if not block.header_block.proof_of_time.is_valid(
self.constants["DISCRIMINANT_SIZE_BITS"]
):
return False
if block.body.coinbase.height != block.header_block.challenge.height:
@@ -482,22 +593,31 @@ class Blockchain:
return False
# 5. and check if PoT.challenge_hash matches
if (block.header_block.proof_of_time.challenge_hash !=
prev_block.header_block.challenge.get_hash()):
if (
block.header_block.proof_of_time.challenge_hash
!= prev_block.header_block.challenge.get_hash()
):
return False
# 6a. Check challenge height = parent height + 1
if block.header_block.challenge.height != prev_block.header_block.challenge.height + 1:
if (
block.header_block.challenge.height
!= prev_block.header_block.challenge.height + 1
):
return False
# 7a. Check challenge total_weight = parent total_weight + difficulty
if (block.header_block.challenge.total_weight !=
prev_block.header_block.challenge.total_weight + difficulty):
if (
block.header_block.challenge.total_weight
!= prev_block.header_block.challenge.total_weight + difficulty
):
return False
# 8a. Check challenge total_iters = parent total_iters + number_iters
if (block.header_block.challenge.total_iters !=
prev_block.header_block.challenge.total_iters + number_of_iters):
if (
block.header_block.challenge.total_iters
!= prev_block.header_block.challenge.total_iters + number_of_iters
):
return False
else:
# 6b. Check challenge height = parent height + 1
@@ -514,7 +634,9 @@ class Blockchain:
return True
async def _reconsider_heights(self, old_lca: Optional[FullBlock], new_lca: FullBlock):
async def _reconsider_heights(
self, old_lca: Optional[FullBlock], new_lca: FullBlock
):
"""
Update the mapping from height to block hash, when the lca changes.
"""
@@ -525,16 +647,24 @@ class Blockchain:
self.height_to_hash[uint64(curr_new.height)] = curr_new.header_hash
if curr_new.height == 0:
return
curr_new = (await self.store.get_block(curr_new.prev_header_hash)).header_block
curr_new = (
await self.store.get_block(curr_new.prev_header_hash)
).header_block
elif curr_old.height > curr_new.height:
del self.height_to_hash[uint64(curr_old.height)]
curr_old = (await self.store.get_block(curr_old.prev_header_hash)).header_block
curr_old = (
await self.store.get_block(curr_old.prev_header_hash)
).header_block
else:
if curr_new.header_hash == curr_old.header_hash:
return
self.height_to_hash[uint64(curr_new.height)] = curr_new.header_hash
curr_new = (await self.store.get_block(curr_new.prev_header_hash)).header_block
curr_old = (await self.store.get_block(curr_old.prev_header_hash)).header_block
curr_new = (
await self.store.get_block(curr_new.prev_header_hash)
).header_block
curr_old = (
await self.store.get_block(curr_old.prev_header_hash)
).header_block
async def _reconsider_lca(self, genesis: bool):
"""

File diff suppressed because one or more lines are too long

View File

@@ -36,41 +36,60 @@ def _quality_to_decimal(quality: bytes32) -> Decimal:
return -Decimal(numerator // denominator) / Decimal(t)
def calculate_iterations_quality(quality: bytes32, size: uint8, difficulty: uint64,
vdf_ips: uint64, min_block_time: uint64) -> uint64:
def calculate_iterations_quality(
quality: bytes32,
size: uint8,
difficulty: uint64,
vdf_ips: uint64,
min_block_time: uint64,
) -> uint64:
"""
Calculates the number of iterations from the quality. The quality is converted to a number
between 0 and 1, then divided by expected plot size, and finally multiplied by the
difficulty.
"""
min_iterations = min_block_time * vdf_ips
dec_iters = (Decimal(int(difficulty) << 32) *
(_quality_to_decimal(quality) / _expected_plot_size(size)))
iters_final = uint64(int(min_iterations + dec_iters.to_integral_exact(rounding=ROUND_UP)))
dec_iters = Decimal(int(difficulty) << 32) * (
_quality_to_decimal(quality) / _expected_plot_size(size)
)
iters_final = uint64(
int(min_iterations + dec_iters.to_integral_exact(rounding=ROUND_UP))
)
assert iters_final >= 1
return iters_final
def calculate_iterations(proof_of_space: ProofOfSpace, difficulty: uint64, vdf_ips: uint64,
min_block_time: uint64) -> uint64:
def calculate_iterations(
proof_of_space: ProofOfSpace,
difficulty: uint64,
vdf_ips: uint64,
min_block_time: uint64,
) -> uint64:
"""
Convenience function to calculate the number of iterations using the proof instead
of the quality. The quality must be retrieved from the proof.
"""
quality: bytes32 = proof_of_space.verify_and_get_quality()
return calculate_iterations_quality(quality, proof_of_space.size, difficulty, vdf_ips, min_block_time)
return calculate_iterations_quality(
quality, proof_of_space.size, difficulty, vdf_ips, min_block_time
)
def calculate_ips_from_iterations(proof_of_space: ProofOfSpace, difficulty: uint64,
iterations: uint64, min_block_time: uint64) -> uint64:
def calculate_ips_from_iterations(
proof_of_space: ProofOfSpace,
difficulty: uint64,
iterations: uint64,
min_block_time: uint64,
) -> uint64:
"""
Using the total number of iterations on a block (which is encoded in the block) along with
other details, we can calculate the VDF speed (iterations per second) used to compute the
constant factor in iterations, which is not written into the block.
"""
quality: bytes32 = proof_of_space.verify_and_get_quality()
dec_iters = (Decimal(int(difficulty) << 32) *
(_quality_to_decimal(quality) / _expected_plot_size(proof_of_space.size)))
dec_iters = Decimal(int(difficulty) << 32) * (
_quality_to_decimal(quality) / _expected_plot_size(proof_of_space.size)
)
iters_rounded = int(dec_iters.to_integral_exact(rounding=ROUND_UP))
min_iterations = uint64(iterations - iters_rounded)
ips = min_iterations / min_block_time
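calculate_ips_from_iterations is the inverse of calculate_iterations_quality in the min_iterations term: the quality-dependent part is recomputed and subtracted, and what remains is min_block_time * ips. A self-contained round-trip sketch (quality_factor stands in for the Decimal term, with an assumed value):

min_block_time, vdf_ips = 10, 3000  # assumed values
quality_factor = 123_456  # stands in for ceil((difficulty << 32) * quality / plot_size)

iterations = min_block_time * vdf_ips + quality_factor          # forward direction
recovered_ips = (iterations - quality_factor) / min_block_time  # inverse direction
assert recovered_ips == vdf_ips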

View File

@@ -161,11 +161,7 @@ class FullNodeStore(Database):
return self.potential_blocks_received[height]
async def add_candidate_block(
self,
pos_hash: bytes32,
body: Body,
header: HeaderData,
pos: ProofOfSpace,
self, pos_hash: bytes32, body: Body, header: HeaderData, pos: ProofOfSpace,
):
await self.candidate_blocks.find_one_and_update(
{"_id": pos_hash},

View File

@@ -11,8 +11,7 @@ from src.consensus.block_rewards import calculate_block_reward
from src.consensus.constants import constants
from src.consensus.pot_iterations import calculate_iterations_quality
from src.protocols import farmer_protocol, harvester_protocol
from src.server.outbound_message import (Delivery, Message, NodeType,
OutboundMessage)
from src.server.outbound_message import Delivery, Message, NodeType, OutboundMessage
from src.types.coinbase import CoinbaseInfo
from src.types.proof_of_space import ProofOfSpace
from src.types.sized_bytes import bytes32
@@ -32,7 +31,9 @@ class Farmer:
config_filename = os.path.join(ROOT_DIR, "src", "config", "config.yaml")
key_config_filename = os.path.join(ROOT_DIR, "src", "config", "keys.yaml")
if not os.path.isfile(key_config_filename):
raise RuntimeError("Keys not generated. Run ./src/scripts/regenerate_keys.py.")
raise RuntimeError(
"Keys not generated. Run ./src/scripts/regenerate_keys.py."
)
self.config = safe_load(open(config_filename, "r"))["farmer"]
self.key_config = safe_load(open(key_config_filename, "r"))
self.harvester_responses_header_hash: Dict[bytes32, bytes32] = {}
@@ -49,7 +50,9 @@ class Farmer:
self.proof_of_time_estimate_ips: uint64 = uint64(3000)
@api_request
async def challenge_response(self, challenge_response: harvester_protocol.ChallengeResponse):
async def challenge_response(
self, challenge_response: harvester_protocol.ChallengeResponse
):
"""
This is a response from the harvester, for a NewChallenge. Here we check if the proof
of space is sufficiently good, and if so, we ask for the whole proof.
@@ -66,36 +69,58 @@ class Farmer:
if difficulty == 0:
raise RuntimeError("Did not find challenge")
number_iters: uint64 = calculate_iterations_quality(challenge_response.quality,
challenge_response.plot_size,
difficulty,
self.proof_of_time_estimate_ips,
constants["MIN_BLOCK_TIME"])
number_iters: uint64 = calculate_iterations_quality(
challenge_response.quality,
challenge_response.plot_size,
difficulty,
self.proof_of_time_estimate_ips,
constants["MIN_BLOCK_TIME"],
)
if height < 300: # As the difficulty adjusts, don't fetch all qualities
if challenge_response.challenge_hash not in self.challenge_to_best_iters:
self.challenge_to_best_iters[challenge_response.challenge_hash] = number_iters
elif number_iters < self.challenge_to_best_iters[challenge_response.challenge_hash]:
self.challenge_to_best_iters[challenge_response.challenge_hash] = number_iters
self.challenge_to_best_iters[
challenge_response.challenge_hash
] = number_iters
elif (
number_iters
< self.challenge_to_best_iters[challenge_response.challenge_hash]
):
self.challenge_to_best_iters[
challenge_response.challenge_hash
] = number_iters
else:
return
estimate_secs: float = number_iters / self.proof_of_time_estimate_ips
log.info(f"Estimate: {estimate_secs}, rate: {self.proof_of_time_estimate_ips}")
if estimate_secs < self.config['pool_share_threshold'] or estimate_secs < self.config['propagate_threshold']:
self.harvester_responses_challenge[challenge_response.quality] = challenge_response.challenge_hash
if (
estimate_secs < self.config["pool_share_threshold"]
or estimate_secs < self.config["propagate_threshold"]
):
self.harvester_responses_challenge[
challenge_response.quality
] = challenge_response.challenge_hash
request = harvester_protocol.RequestProofOfSpace(challenge_response.quality)
yield OutboundMessage(NodeType.HARVESTER, Message("request_proof_of_space", request), Delivery.RESPOND)
yield OutboundMessage(
NodeType.HARVESTER,
Message("request_proof_of_space", request),
Delivery.RESPOND,
)
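The gate above converts iterations into an estimated proof-of-time duration and only pursues proofs fast enough to matter. A hedged sketch with illustrative numbers (3000 ips is the default set in __init__ above; the thresholds are assumed config values):

proof_of_time_estimate_ips = 3000  # default from __init__ above
number_iters = 120_000
pool_share_threshold, propagate_threshold = 60.0, 30.0  # assumed config values

estimate_secs = number_iters / proof_of_time_estimate_ips  # 40.0 seconds
if estimate_secs < pool_share_threshold or estimate_secs < propagate_threshold:
    print("worth a pool partial and/or a block: request the full proof of space")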
@api_request
async def respond_proof_of_space(self, response: harvester_protocol.RespondProofOfSpace):
async def respond_proof_of_space(
self, response: harvester_protocol.RespondProofOfSpace
):
"""
This is a response from the harvester with a proof of space. We check its validity,
and request a pool partial, a header signature, or both, if the proof is good enough.
"""
pool_sks: List[PrivateKey] = [PrivateKey.from_bytes(bytes.fromhex(ce))
for ce in self.key_config["pool_sks"]]
pool_sks: List[PrivateKey] = [
PrivateKey.from_bytes(bytes.fromhex(ce))
for ce in self.key_config["pool_sks"]
]
assert response.proof.pool_pubkey in [sk.get_public_key() for sk in pool_sks]
challenge_hash: bytes32 = self.harvester_responses_challenge[response.quality]
@@ -112,34 +137,55 @@ class Farmer:
assert response.quality == computed_quality
self.harvester_responses_proofs[response.quality] = response.proof
self.harvester_responses_proof_hash_to_qual[response.proof.get_hash()] = response.quality
self.harvester_responses_proof_hash_to_qual[
response.proof.get_hash()
] = response.quality
number_iters: uint64 = calculate_iterations_quality(computed_quality,
response.proof.size,
difficulty,
self.proof_of_time_estimate_ips,
constants["MIN_BLOCK_TIME"])
number_iters: uint64 = calculate_iterations_quality(
computed_quality,
response.proof.size,
difficulty,
self.proof_of_time_estimate_ips,
constants["MIN_BLOCK_TIME"],
)
estimate_secs: float = number_iters / self.proof_of_time_estimate_ips
if estimate_secs < self.config['pool_share_threshold']:
request1 = harvester_protocol.RequestPartialProof(response.quality,
sha256(bytes.fromhex(
self.key_config['farmer_target'])).digest())
yield OutboundMessage(NodeType.HARVESTER, Message("request_partial_proof", request1), Delivery.RESPOND)
if estimate_secs < self.config['propagate_threshold']:
if estimate_secs < self.config["pool_share_threshold"]:
request1 = harvester_protocol.RequestPartialProof(
response.quality,
sha256(bytes.fromhex(self.key_config["farmer_target"])).digest(),
)
yield OutboundMessage(
NodeType.HARVESTER,
Message("request_partial_proof", request1),
Delivery.RESPOND,
)
if estimate_secs < self.config["propagate_threshold"]:
if new_proof_height not in self.coinbase_rewards:
log.error(f"Don't have coinbase transaction for height {new_proof_height}, cannot submit PoS")
log.error(
f"Don't have coinbase transaction for height {new_proof_height}, cannot submit PoS"
)
return
coinbase, signature = self.coinbase_rewards[new_proof_height]
request2 = farmer_protocol.RequestHeaderHash(challenge_hash, coinbase, signature,
bytes.fromhex(self.key_config['farmer_target']),
response.proof)
request2 = farmer_protocol.RequestHeaderHash(
challenge_hash,
coinbase,
signature,
bytes.fromhex(self.key_config["farmer_target"]),
response.proof,
)
yield OutboundMessage(NodeType.FULL_NODE, Message("request_header_hash", request2), Delivery.BROADCAST)
yield OutboundMessage(
NodeType.FULL_NODE,
Message("request_header_hash", request2),
Delivery.BROADCAST,
)
@api_request
async def respond_header_signature(self, response: harvester_protocol.RespondHeaderSignature):
async def respond_header_signature(
self, response: harvester_protocol.RespondHeaderSignature
):
"""
Receives a signature on a block header hash, which is required for submitting
a block to the blockchain.
@@ -148,26 +194,36 @@ class Farmer:
proof_of_space: bytes32 = self.harvester_responses_proofs[response.quality]
plot_pubkey = self.harvester_responses_proofs[response.quality].plot_pubkey
assert response.header_hash_signature.verify([Util.hash256(header_hash)],
[plot_pubkey])
assert response.header_hash_signature.verify(
[Util.hash256(header_hash)], [plot_pubkey]
)
pos_hash: bytes32 = proof_of_space.get_hash()
request = farmer_protocol.HeaderSignature(pos_hash, header_hash, response.header_hash_signature)
yield OutboundMessage(NodeType.FULL_NODE, Message("header_signature", request), Delivery.BROADCAST)
request = farmer_protocol.HeaderSignature(
pos_hash, header_hash, response.header_hash_signature
)
yield OutboundMessage(
NodeType.FULL_NODE, Message("header_signature", request), Delivery.BROADCAST
)
@api_request
async def respond_partial_proof(self, response: harvester_protocol.RespondPartialProof):
async def respond_partial_proof(
self, response: harvester_protocol.RespondPartialProof
):
"""
Receives a signature on the hash of the farmer payment target, which is used in a pool
share, to tell the pool where to pay the farmer.
"""
farmer_target_hash = sha256(bytes.fromhex(self.key_config['farmer_target'])).digest()
farmer_target_hash = sha256(
bytes.fromhex(self.key_config["farmer_target"])
).digest()
plot_pubkey = self.harvester_responses_proofs[response.quality].plot_pubkey
assert response.farmer_target_signature.verify([Util.hash256(farmer_target_hash)],
[plot_pubkey])
assert response.farmer_target_signature.verify(
[Util.hash256(farmer_target_hash)], [plot_pubkey]
)
# TODO: Send partial to pool
"""
@@ -181,51 +237,84 @@ class Farmer:
"""
header_hash: bytes32 = response.header_hash
quality: bytes32 = self.harvester_responses_proof_hash_to_qual[response.pos_hash]
quality: bytes32 = self.harvester_responses_proof_hash_to_qual[
response.pos_hash
]
self.harvester_responses_header_hash[quality] = header_hash
# TODO: only send to the harvester who made the proof of space, not all harvesters
request = harvester_protocol.RequestHeaderSignature(quality, header_hash)
yield OutboundMessage(NodeType.HARVESTER, Message("request_header_signature", request), Delivery.BROADCAST)
yield OutboundMessage(
NodeType.HARVESTER,
Message("request_header_signature", request),
Delivery.BROADCAST,
)
@api_request
async def proof_of_space_finalized(self, proof_of_space_finalized: farmer_protocol.ProofOfSpaceFinalized):
async def proof_of_space_finalized(
self, proof_of_space_finalized: farmer_protocol.ProofOfSpaceFinalized
):
"""
Full node notifies farmer that a proof of space has been completed. It gets added to the
challenges list at that height, and height is updated if necessary
"""
get_proofs: bool = False
if (proof_of_space_finalized.height >= self.current_height and
proof_of_space_finalized.challenge_hash not in self.seen_challenges):
if (
proof_of_space_finalized.height >= self.current_height
and proof_of_space_finalized.challenge_hash not in self.seen_challenges
):
# Only get proofs for new challenges, at a current or new height
get_proofs = True
if (proof_of_space_finalized.height > self.current_height):
if proof_of_space_finalized.height > self.current_height:
self.current_height = proof_of_space_finalized.height
# TODO: ask the pool for this information
coinbase: CoinbaseInfo = CoinbaseInfo(uint32(self.current_height + 1),
calculate_block_reward(self.current_height),
bytes.fromhex(self.key_config["pool_target"]))
coinbase: CoinbaseInfo = CoinbaseInfo(
uint32(self.current_height + 1),
calculate_block_reward(self.current_height),
bytes.fromhex(self.key_config["pool_target"]),
)
pool_sks: List[PrivateKey] = [PrivateKey.from_bytes(bytes.fromhex(ce))
for ce in self.key_config["pool_sks"]]
coinbase_signature: PrependSignature = pool_sks[0].sign_prepend(bytes(coinbase))
self.coinbase_rewards[uint32(self.current_height + 1)] = (coinbase, coinbase_signature)
pool_sks: List[PrivateKey] = [
PrivateKey.from_bytes(bytes.fromhex(ce))
for ce in self.key_config["pool_sks"]
]
coinbase_signature: PrependSignature = pool_sks[0].sign_prepend(
bytes(coinbase)
)
self.coinbase_rewards[uint32(self.current_height + 1)] = (
coinbase,
coinbase_signature,
)
log.info(f"\tCurrent height set to {self.current_height}")
self.seen_challenges.add(proof_of_space_finalized.challenge_hash)
if proof_of_space_finalized.height not in self.challenges:
self.challenges[proof_of_space_finalized.height] = [proof_of_space_finalized]
self.challenges[proof_of_space_finalized.height] = [
proof_of_space_finalized
]
else:
self.challenges[proof_of_space_finalized.height].append(proof_of_space_finalized)
self.challenge_to_height[proof_of_space_finalized.challenge_hash] = proof_of_space_finalized.height
self.challenges[proof_of_space_finalized.height].append(
proof_of_space_finalized
)
self.challenge_to_height[
proof_of_space_finalized.challenge_hash
] = proof_of_space_finalized.height
if get_proofs:
message = harvester_protocol.NewChallenge(proof_of_space_finalized.challenge_hash)
yield OutboundMessage(NodeType.HARVESTER, Message("new_challenge", message), Delivery.BROADCAST)
message = harvester_protocol.NewChallenge(
proof_of_space_finalized.challenge_hash
)
yield OutboundMessage(
NodeType.HARVESTER,
Message("new_challenge", message),
Delivery.BROADCAST,
)
@api_request
async def proof_of_space_arrived(self, proof_of_space_arrived: farmer_protocol.ProofOfSpaceArrived):
async def proof_of_space_arrived(
self, proof_of_space_arrived: farmer_protocol.ProofOfSpaceArrived
):
"""
Full node notifies the farmer that a new proof of space was created. The farmer can use this
information to decide whether to propagate a proof.
@@ -233,10 +322,14 @@ class Farmer:
if proof_of_space_arrived.height not in self.unfinished_challenges:
self.unfinished_challenges[proof_of_space_arrived.height] = []
else:
self.unfinished_challenges[proof_of_space_arrived.height].append(proof_of_space_arrived.quality)
self.unfinished_challenges[proof_of_space_arrived.height].append(
proof_of_space_arrived.quality
)
@api_request
async def deep_reorg_notification(self, deep_reorg_notification: farmer_protocol.DeepReorgNotification):
async def deep_reorg_notification(
self, deep_reorg_notification: farmer_protocol.DeepReorgNotification
):
"""
Resets everything. This will be triggered when a long reorg happens, which means blocks of lower
height (but greater weight) might come.
@@ -253,7 +346,9 @@ class Farmer:
self.coinbase_rewards = {}
@api_request
async def proof_of_time_rate(self, proof_of_time_rate: farmer_protocol.ProofOfTimeRate):
async def proof_of_time_rate(
self, proof_of_time_rate: farmer_protocol.ProofOfTimeRate
):
"""
Updates our internal estimate of the iterations per second for the fastest proof of time
in the network.

View File

@@ -19,8 +19,7 @@ from src.consensus.pot_iterations import calculate_iterations
from src.consensus.weight_verifier import verify_weight
from src.db.database import FullNodeStore
from src.protocols import farmer_protocol, peer_protocol, timelord_protocol
from src.server.outbound_message import (Delivery, Message, NodeType,
OutboundMessage)
from src.server.outbound_message import Delivery, Message, NodeType, OutboundMessage
from src.server.server import ChiaServer
from src.types.body import Body
from src.types.challenge import Challenge
@@ -32,8 +31,11 @@ from src.types.peer_info import PeerInfo
from src.types.sized_bytes import bytes32
from src.util import errors
from src.util.api_decorators import api_request
from src.util.errors import (BlockNotInBlockchain, InvalidUnfinishedBlock,
PeersDontHaveBlock)
from src.util.errors import (
BlockNotInBlockchain,
InvalidUnfinishedBlock,
PeersDontHaveBlock,
)
from src.util.ints import uint32, uint64
log = logging.getLogger(__name__)
@@ -53,8 +55,9 @@ class FullNode:
def _set_server(self, server: ChiaServer):
self.server = server
async def _send_tips_to_farmers(self, delivery: Delivery = Delivery.BROADCAST) -> \
AsyncGenerator[OutboundMessage, None]:
async def _send_tips_to_farmers(
self, delivery: Delivery = Delivery.BROADCAST
) -> AsyncGenerator[OutboundMessage, None]:
"""
Sends all of the current heads to all farmer peers. Also sends the latest
estimated proof of time rate, so the farmer can calculate which proofs are good.
@@ -67,19 +70,29 @@ class FullNode:
height = tip.challenge.height
quality = tip.proof_of_space.verify_and_get_quality()
if tip.height > 0:
difficulty: uint64 = await self.blockchain.get_next_difficulty(tip.prev_header_hash)
difficulty: uint64 = await self.blockchain.get_next_difficulty(
tip.prev_header_hash
)
else:
difficulty = tip.weight
requests.append(farmer_protocol.ProofOfSpaceFinalized(challenge_hash, height,
quality, difficulty))
requests.append(
farmer_protocol.ProofOfSpaceFinalized(
challenge_hash, height, quality, difficulty
)
)
proof_of_time_rate: uint64 = await self.store.get_proof_of_time_estimate_ips()
for request in requests:
yield OutboundMessage(NodeType.FARMER, Message("proof_of_space_finalized", request), delivery)
yield OutboundMessage(
NodeType.FARMER, Message("proof_of_space_finalized", request), delivery
)
rate_update = farmer_protocol.ProofOfTimeRate(proof_of_time_rate)
yield OutboundMessage(NodeType.FARMER, Message("proof_of_time_rate", rate_update), delivery)
yield OutboundMessage(
NodeType.FARMER, Message("proof_of_time_rate", rate_update), delivery
)
async def _send_challenges_to_timelords(self, delivery: Delivery = Delivery.BROADCAST) -> \
AsyncGenerator[OutboundMessage, None]:
async def _send_challenges_to_timelords(
self, delivery: Delivery = Delivery.BROADCAST
) -> AsyncGenerator[OutboundMessage, None]:
"""
Sends all of the current heads to all timelord peers.
"""
@@ -88,10 +101,16 @@ class FullNode:
for head in self.blockchain.get_current_tips():
assert head.challenge
challenge_hash = head.challenge.get_hash()
requests.append(timelord_protocol.ChallengeStart(challenge_hash, head.challenge.total_weight))
requests.append(
timelord_protocol.ChallengeStart(
challenge_hash, head.challenge.total_weight
)
)
for request in requests:
yield OutboundMessage(NodeType.TIMELORD, Message("challenge_start", request), delivery)
yield OutboundMessage(
NodeType.TIMELORD, Message("challenge_start", request), delivery
)
async def _on_connect(self) -> AsyncGenerator[OutboundMessage, None]:
"""
@@ -108,7 +127,9 @@ class FullNode:
blocks.append(block)
for block in blocks:
request = peer_protocol.Block(block)
yield OutboundMessage(NodeType.FULL_NODE, Message("block", request), Delivery.RESPOND)
yield OutboundMessage(
NodeType.FULL_NODE, Message("block", request), Delivery.RESPOND
)
# Update farmers and timelord with most recent information
async for msg in self._send_challenges_to_timelords(Delivery.RESPOND):
@@ -117,8 +138,9 @@ class FullNode:
yield msg
def _num_needed_peers(self):
diff = self.config['target_peer_count'] - \
len(self.server.global_connections.get_full_node_connections())
diff = self.config["target_peer_count"] - len(
self.server.global_connections.get_full_node_connections()
)
return diff if diff >= 0 else 0
def _start_bg_tasks(self):
@@ -126,20 +148,21 @@ class FullNode:
Start a background task connecting periodically to the introducer and
requesting the peer list.
"""
introducer = self.config['introducer_peer']
introducer_peerinfo = PeerInfo(introducer['host'], introducer['port'])
introducer = self.config["introducer_peer"]
introducer_peerinfo = PeerInfo(introducer["host"], introducer["port"])
async def introducer_client():
async def on_connect():
msg = Message("request_peers", peer_protocol.RequestPeers())
yield OutboundMessage(NodeType.INTRODUCER, msg,
Delivery.RESPOND)
yield OutboundMessage(NodeType.INTRODUCER, msg, Delivery.RESPOND)
while True:
if self._num_needed_peers():
if not await self.server.start_client(introducer_peerinfo,
on_connect):
if not await self.server.start_client(
introducer_peerinfo, on_connect
):
continue
await asyncio.sleep(self.config['introducer_connect_interval'])
await asyncio.sleep(self.config["introducer_connect_interval"])
self._bg_tasks.add(asyncio.create_task(introducer_client()))
@@ -177,7 +200,9 @@ class FullNode:
highest_weight = block.header_block.challenge.total_weight
tip_block = block
tip_height = block.header_block.challenge.height
if highest_weight <= max([t.weight for t in self.blockchain.get_current_tips()]):
if highest_weight <= max(
[t.weight for t in self.blockchain.get_current_tips()]
):
log.info("Not performing sync, already caught up.")
await self.store.set_sync_mode(False)
await self.store.clear_sync_info()
@@ -192,12 +217,22 @@ class FullNode:
total_time_slept = 0
headers: List[HeaderBlock] = []
while total_time_slept < timeout:
for start_height in range(0, tip_height + 1, self.config['max_headers_to_send']):
end_height = min(start_height + self.config['max_headers_to_send'], tip_height + 1)
request = peer_protocol.RequestHeaderBlocks(tip_block.header_block.header.get_hash(),
[uint64(h) for h in range(start_height, end_height)])
for start_height in range(
0, tip_height + 1, self.config["max_headers_to_send"]
):
end_height = min(
start_height + self.config["max_headers_to_send"], tip_height + 1
)
request = peer_protocol.RequestHeaderBlocks(
tip_block.header_block.header.get_hash(),
[uint64(h) for h in range(start_height, end_height)],
)
# TODO: should we ask the same peer as before, for the headers?
yield OutboundMessage(NodeType.FULL_NODE, Message("request_header_blocks", request), Delivery.RANDOM)
yield OutboundMessage(
NodeType.FULL_NODE,
Message("request_header_blocks", request),
Delivery.RANDOM,
)
await asyncio.sleep(sleep_interval)
total_time_slept += sleep_interval
@@ -209,16 +244,22 @@ class FullNode:
if await self.store.get_potential_header(uint32(height)) is None:
received_all_headers = False
break
local_headers.append(await self.store.get_potential_header(uint32(height)))
local_headers.append(
await self.store.get_potential_header(uint32(height))
)
if received_all_headers:
headers = local_headers
break
log.error(f"Downloaded headers up to tip height: {tip_height}")
if not verify_weight(tip_block.header_block, headers):
# TODO: ban peers that provided the invalid heads or proofs
raise errors.InvalidWeight(f"Weight of {tip_block.header_block.header.get_hash()} not valid.")
raise errors.InvalidWeight(
f"Weight of {tip_block.header_block.header.get_hash()} not valid."
)
log.error(f"Validated weight of headers. Downloaded {len(headers)} headers, tip height {tip_height}")
log.error(
f"Validated weight of headers. Downloaded {len(headers)} headers, tip height {tip_height}"
)
assert tip_height + 1 == len(headers)
async with self.store.lock:
@@ -228,30 +269,51 @@ class FullNode:
for height in range(fork_point.height + 1, tip_height + 1):
# Only download from fork point (what we don't have)
async with self.store.lock:
have_block = await self.store.get_potential_heads_full_block(headers[height].header.get_hash()) is not None
have_block = (
await self.store.get_potential_heads_full_block(
headers[height].header.get_hash()
)
is not None
)
if not have_block:
request_sync = peer_protocol.RequestSyncBlocks(tip_block.header_block.header.header_hash,
[uint64(height)])
request_sync = peer_protocol.RequestSyncBlocks(
tip_block.header_block.header.header_hash, [uint64(height)]
)
async with self.store.lock:
await self.store.set_potential_blocks_received(uint32(height), Event())
await self.store.set_potential_blocks_received(
uint32(height), Event()
)
found = False
for _ in range(30):
yield OutboundMessage(NodeType.FULL_NODE, Message("request_sync_blocks", request_sync),
Delivery.RANDOM)
yield OutboundMessage(
NodeType.FULL_NODE,
Message("request_sync_blocks", request_sync),
Delivery.RANDOM,
)
try:
await asyncio.wait_for((await self.store.get_potential_blocks_received(uint32(height))).wait(),
timeout=2)
await asyncio.wait_for(
(
await self.store.get_potential_blocks_received(
uint32(height)
)
).wait(),
timeout=2,
)
found = True
break
except concurrent.futures.TimeoutError:
log.info("Did not receive desired block")
if not found:
raise PeersDontHaveBlock(f"Did not receive desired block at height {height}")
raise PeersDontHaveBlock(
f"Did not receive desired block at height {height}"
)
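The download loop above polls for each block with a bounded retry: broadcast a request to a random peer, wait on an Event for two seconds, give up after 30 attempts. A self-contained sketch of the same pattern (three attempts here so the demo finishes quickly; it catches asyncio.TimeoutError, which on Python 3.7 is the same class as the concurrent.futures.TimeoutError the commit's code catches):

import asyncio

async def fetch_with_retries(received: asyncio.Event, attempts: int = 3) -> bool:
    for _ in range(attempts):
        # (the real loop yields a request_sync_blocks message to a random peer here)
        try:
            await asyncio.wait_for(received.wait(), timeout=2)
            return True
        except asyncio.TimeoutError:
            pass  # logged as "Did not receive desired block", then retried
    return False  # the caller raises PeersDontHaveBlock

print(asyncio.run(fetch_with_retries(asyncio.Event())))  # False after ~6 s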
async with self.store.lock:
# TODO: ban peers that provide bad blocks
if have_block:
block = await self.store.get_potential_head(headers[height].header.get_hash())
block = await self.store.get_potential_head(
headers[height].header.get_hash()
)
else:
block = await self.store.get_potential_block(uint32(height))
@@ -259,11 +321,21 @@ class FullNode:
start = time.time()
result = await self.blockchain.receive_block(block)
if result == ReceiveBlockResult.INVALID_BLOCK or result == ReceiveBlockResult.DISCONNECTED_BLOCK:
if (
result == ReceiveBlockResult.INVALID_BLOCK
or result == ReceiveBlockResult.DISCONNECTED_BLOCK
):
raise RuntimeError(f"Invalid block {block.header_hash}")
log.info(f"Took {time.time() - start} seconds to validate and add block {block.height}.")
assert max([h.height for h in self.blockchain.get_current_tips()]) >= height
await self.store.set_proof_of_time_estimate_ips(await self.blockchain.get_next_ips(block.header_hash))
log.info(
f"Took {time.time() - start} seconds to validate and add block {block.height}."
)
assert (
max([h.height for h in self.blockchain.get_current_tips()])
>= height
)
await self.store.set_proof_of_time_estimate_ips(
await self.blockchain.get_next_ips(block.header_hash)
)
async with self.store.lock:
log.info(f"Finished sync up to height {tip_height}")
@@ -277,19 +349,25 @@ class FullNode:
yield msg
@api_request
async def request_header_blocks(self, request: peer_protocol.RequestHeaderBlocks) \
-> AsyncGenerator[OutboundMessage, None]:
async def request_header_blocks(
self, request: peer_protocol.RequestHeaderBlocks
) -> AsyncGenerator[OutboundMessage, None]:
"""
A peer requests a list of header blocks, by height. Used for syncing or light clients.
"""
if len(request.heights) > self.config['max_headers_to_send']:
raise errors.TooManyheadersRequested(f"The max number of headers is {self.config['max_headers_to_send']},\
but requested {len(request.heights)}")
if len(request.heights) > self.config["max_headers_to_send"]:
raise errors.TooManyheadersRequested(
f"The max number of headers is {self.config['max_headers_to_send']},\
but requested {len(request.heights)}"
)
async with self.store.lock:
try:
headers: List[HeaderBlock] = await self.blockchain.get_header_blocks_by_height(request.heights,
request.tip_header_hash)
headers: List[
HeaderBlock
] = await self.blockchain.get_header_blocks_by_height(
request.heights, request.tip_header_hash
)
except KeyError:
return
except BlockNotInBlockchain as e:
@@ -297,11 +375,14 @@ class FullNode:
return
response = peer_protocol.HeaderBlocks(request.tip_header_hash, headers)
yield OutboundMessage(NodeType.FULL_NODE, Message("header_blocks", response), Delivery.RESPOND)
yield OutboundMessage(
NodeType.FULL_NODE, Message("header_blocks", response), Delivery.RESPOND
)
@api_request
async def header_blocks(self, request: peer_protocol.HeaderBlocks) \
-> AsyncGenerator[OutboundMessage, None]:
async def header_blocks(
self, request: peer_protocol.HeaderBlocks
) -> AsyncGenerator[OutboundMessage, None]:
"""
Receive header blocks from a peer.
"""
@@ -313,24 +394,34 @@ class FullNode:
yield _
@api_request
async def request_sync_blocks(self, request: peer_protocol.RequestSyncBlocks) \
-> AsyncGenerator[OutboundMessage, None]:
async def request_sync_blocks(
self, request: peer_protocol.RequestSyncBlocks
) -> AsyncGenerator[OutboundMessage, None]:
"""
Respond to a peer's request for syncing blocks.
"""
blocks: List[FullBlock] = []
async with self.store.lock:
tip_block: Optional[FullBlock] = await self.blockchain.get_block(request.tip_header_hash)
tip_block: Optional[FullBlock] = await self.blockchain.get_block(
request.tip_header_hash
)
if tip_block is not None:
if len(request.heights) > self.config['max_blocks_to_send']:
raise errors.TooManyheadersRequested(f"The max number of blocks is "
f"{self.config['max_blocks_to_send']},"
f"but requested {len(request.heights)}")
if len(request.heights) > self.config["max_blocks_to_send"]:
raise errors.TooManyheadersRequested(
f"The max number of blocks is "
f"{self.config['max_blocks_to_send']},"
f"but requested {len(request.heights)}"
)
try:
header_blocks: List[HeaderBlock] = await self.blockchain.get_header_blocks_by_height(
request.heights, request.tip_header_hash)
header_blocks: List[
HeaderBlock
] = await self.blockchain.get_header_blocks_by_height(
request.heights, request.tip_header_hash
)
for header_block in header_blocks:
fetched = await self.blockchain.get_block(header_block.header.get_hash())
fetched = await self.blockchain.get_block(
header_block.header.get_hash()
)
assert fetched
blocks.append(fetched)
except KeyError:
@@ -341,13 +432,19 @@ class FullNode:
return
else:
# We don't have the blocks that the client is looking for
log.info(f"Peer requested tip {request.tip_header_hash} that we don't have")
log.info(
f"Peer requested tip {request.tip_header_hash} that we don't have"
)
return
response = Message("sync_blocks", peer_protocol.SyncBlocks(request.tip_header_hash, blocks))
response = Message(
"sync_blocks", peer_protocol.SyncBlocks(request.tip_header_hash, blocks)
)
yield OutboundMessage(NodeType.FULL_NODE, response, Delivery.RESPOND)
@api_request
async def sync_blocks(self, request: peer_protocol.SyncBlocks) -> AsyncGenerator[OutboundMessage, None]:
async def sync_blocks(
self, request: peer_protocol.SyncBlocks
) -> AsyncGenerator[OutboundMessage, None]:
"""
We have received the blocks that we needed for syncing. Add them to the processing queue.
"""
@@ -365,8 +462,9 @@ class FullNode:
yield _
@api_request
async def request_header_hash(self, request: farmer_protocol.RequestHeaderHash) \
-> AsyncGenerator[OutboundMessage, None]:
async def request_header_hash(
self, request: farmer_protocol.RequestHeaderHash
) -> AsyncGenerator[OutboundMessage, None]:
"""
Creates a block body and header, with the proof of space, coinbase, and fee targets provided
by the farmer, and sends the hash of the header data back to the farmer.
@@ -374,9 +472,12 @@ class FullNode:
plot_seed: bytes32 = request.proof_of_space.get_plot_seed()
# Checks that the proof of space is valid
quality_string: bytes = Verifier().validate_proof(plot_seed, request.proof_of_space.size,
request.challenge_hash,
bytes(request.proof_of_space.proof))
quality_string: bytes = Verifier().validate_proof(
plot_seed,
request.proof_of_space.size,
request.challenge_hash,
bytes(request.proof_of_space.proof),
)
assert quality_string
async with self.store.lock:
@@ -389,7 +490,9 @@ class FullNode:
target_head = head
if target_head is None:
# TODO: should we still allow the farmer to farm?
log.warning(f"Challenge hash: {request.challenge_hash} not in one of three heads")
log.warning(
f"Challenge hash: {request.challenge_hash} not in one of three heads"
)
return
# TODO: use mempool to grab best transactions, for the selected head
@@ -402,8 +505,14 @@ class FullNode:
cost = uint64(0)
# Creates a block with transactions, coinbase, and fees
body: Body = Body(request.coinbase, request.coinbase_signature,
fees, aggregate_sig, transactions_generator, cost)
body: Body = Body(
request.coinbase,
request.coinbase_signature,
fees,
aggregate_sig,
transactions_generator,
cost,
)
# Creates the block header
prev_header_hash: bytes32 = target_head.header.get_hash()
@@ -414,36 +523,56 @@ class FullNode:
proof_of_space_hash: bytes32 = request.proof_of_space.get_hash()
body_hash: Body = body.get_hash()
extension_data: bytes32 = bytes32([0] * 32)
block_header_data: HeaderData = HeaderData(prev_header_hash, timestamp,
filter_hash, proof_of_space_hash,
body_hash, extension_data)
block_header_data: HeaderData = HeaderData(
prev_header_hash,
timestamp,
filter_hash,
proof_of_space_hash,
body_hash,
extension_data,
)
block_header_data_hash: bytes32 = block_header_data.get_hash()
# Store this block in self.store so we can submit it to the blockchain after it's signed by the harvester
await self.store.add_candidate_block(proof_of_space_hash, body, block_header_data, request.proof_of_space)
await self.store.add_candidate_block(
proof_of_space_hash, body, block_header_data, request.proof_of_space
)
message = farmer_protocol.HeaderHash(proof_of_space_hash, block_header_data_hash)
yield OutboundMessage(NodeType.FARMER, Message("header_hash", message), Delivery.RESPOND)
message = farmer_protocol.HeaderHash(
proof_of_space_hash, block_header_data_hash
)
yield OutboundMessage(
NodeType.FARMER, Message("header_hash", message), Delivery.RESPOND
)
@api_request
async def header_signature(self, header_signature: farmer_protocol.HeaderSignature) \
-> AsyncGenerator[OutboundMessage, None]:
async def header_signature(
self, header_signature: farmer_protocol.HeaderSignature
) -> AsyncGenerator[OutboundMessage, None]:
"""
Signature of header hash, by the harvester. This is enough to create an unfinished
block, which only needs a Proof of Time to be finished. If the signature is valid,
we call the unfinished_block routine.
"""
async with self.store.lock:
if (await self.store.get_candidate_block(header_signature.pos_hash)) is None:
log.warning(f"PoS hash {header_signature.pos_hash} not found in database")
if (
await self.store.get_candidate_block(header_signature.pos_hash)
) is None:
log.warning(
f"PoS hash {header_signature.pos_hash} not found in database"
)
return
# Verifies that we have the correct header and body stored
block_body, block_header_data, pos = await self.store.get_candidate_block(header_signature.pos_hash)
block_body, block_header_data, pos = await self.store.get_candidate_block(
header_signature.pos_hash
)
assert block_header_data.get_hash() == header_signature.header_hash
block_header: Header = Header(block_header_data, header_signature.header_signature)
block_header: Header = Header(
block_header_data, header_signature.header_signature
)
header: HeaderBlock = HeaderBlock(pos, None, None, block_header)
unfinished_block_obj: FullBlock = FullBlock(header, block_body)
@@ -455,46 +584,65 @@ class FullNode:
# TIMELORD PROTOCOL
@api_request
async def proof_of_time_finished(self, request: timelord_protocol.ProofOfTimeFinished) -> \
AsyncGenerator[OutboundMessage, None]:
async def proof_of_time_finished(
self, request: timelord_protocol.ProofOfTimeFinished
) -> AsyncGenerator[OutboundMessage, None]:
"""
A proof of time, received by a peer timelord. We can use this to complete a block,
and call the block routine (which handles propagation and verification of blocks).
"""
async with self.store.lock:
dict_key = (request.proof.challenge_hash, request.proof.number_of_iterations)
dict_key = (
request.proof.challenge_hash,
request.proof.number_of_iterations,
)
unfinished_block_obj: Optional[FullBlock] = await self.store.get_unfinished_block(dict_key)
unfinished_block_obj: Optional[
FullBlock
] = await self.store.get_unfinished_block(dict_key)
if not unfinished_block_obj:
log.warning(f"Received a proof of time that we cannot use to complete a block {dict_key}")
log.warning(
f"Received a proof of time that we cannot use to complete a block {dict_key}"
)
return
prev_block: Optional[HeaderBlock] = await self.blockchain.get_header_block(
unfinished_block_obj.prev_header_hash)
difficulty: uint64 = await self.blockchain.get_next_difficulty(unfinished_block_obj.prev_header_hash)
unfinished_block_obj.prev_header_hash
)
difficulty: uint64 = await self.blockchain.get_next_difficulty(
unfinished_block_obj.prev_header_hash
)
assert prev_block
assert prev_block.challenge
challenge: Challenge = Challenge(request.proof.challenge_hash,
unfinished_block_obj.header_block.proof_of_space.get_hash(),
request.proof.output.get_hash(),
uint32(prev_block.challenge.height + 1),
uint64(prev_block.challenge.total_weight + difficulty),
uint64(prev_block.challenge.total_iters +
request.proof.number_of_iterations))
challenge: Challenge = Challenge(
request.proof.challenge_hash,
unfinished_block_obj.header_block.proof_of_space.get_hash(),
request.proof.output.get_hash(),
uint32(prev_block.challenge.height + 1),
uint64(prev_block.challenge.total_weight + difficulty),
uint64(
prev_block.challenge.total_iters + request.proof.number_of_iterations
),
)
new_header_block = HeaderBlock(unfinished_block_obj.header_block.proof_of_space,
request.proof,
challenge,
unfinished_block_obj.header_block.header)
new_full_block: FullBlock = FullBlock(new_header_block, unfinished_block_obj.body)
new_header_block = HeaderBlock(
unfinished_block_obj.header_block.proof_of_space,
request.proof,
challenge,
unfinished_block_obj.header_block.header,
)
new_full_block: FullBlock = FullBlock(
new_header_block, unfinished_block_obj.body
)
async for msg in self.block(peer_protocol.Block(new_full_block)):
yield msg
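
For readers tracing the arithmetic: the new Challenge simply extends the parent's running totals, adding one to the height, the next difficulty to the total weight, and the proof's iteration count to the total iterations. A minimal sketch of that accumulation, with plain ints standing in for the uint32/uint64 wrappers and invented example values:

# Sketch only: plain ints stand in for uint32/uint64; the numbers are invented.
def extend_totals(prev_height, prev_weight, prev_iters, difficulty, proof_iters):
    return (prev_height + 1, prev_weight + difficulty, prev_iters + proof_iters)

# A parent at height 10 with weight 5000 and 2_000_000 total iters, a next
# difficulty of 500, and a 150_000-iteration proof yield (11, 5500, 2_150_000).
assert extend_totals(10, 5000, 2_000_000, 500, 150_000) == (11, 5500, 2_150_000)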
# PEER PROTOCOL
@api_request
async def new_proof_of_time(self, new_proof_of_time: peer_protocol.NewProofOfTime) \
-> AsyncGenerator[OutboundMessage, None]:
async def new_proof_of_time(
self, new_proof_of_time: peer_protocol.NewProofOfTime
) -> AsyncGenerator[OutboundMessage, None]:
"""
A proof of time, received by a peer full node. If we have the rest of the block,
we can complete it. Otherwise, we just verify and propagate the proof.
@@ -502,8 +650,12 @@ class FullNode:
finish_block: bool = False
propagate_proof: bool = False
async with self.store.lock:
if (await self.store.get_unfinished_block((new_proof_of_time.proof.challenge_hash,
new_proof_of_time.proof.number_of_iterations))):
if await self.store.get_unfinished_block(
(
new_proof_of_time.proof.challenge_hash,
new_proof_of_time.proof.number_of_iterations,
)
):
finish_block = True
elif new_proof_of_time.proof.is_valid(constants["DISCRIMINANT_SIZE_BITS"]):
@@ -513,12 +665,16 @@ class FullNode:
async for msg in self.proof_of_time_finished(request):
yield msg
if propagate_proof:
yield OutboundMessage(NodeType.FULL_NODE, Message("new_proof_of_time", new_proof_of_time),
Delivery.BROADCAST_TO_OTHERS)
yield OutboundMessage(
NodeType.FULL_NODE,
Message("new_proof_of_time", new_proof_of_time),
Delivery.BROADCAST_TO_OTHERS,
)
@api_request
async def unfinished_block(self, unfinished_block: peer_protocol.UnfinishedBlock) \
-> AsyncGenerator[OutboundMessage, None]:
async def unfinished_block(
self, unfinished_block: peer_protocol.UnfinishedBlock
) -> AsyncGenerator[OutboundMessage, None]:
"""
We have received an unfinished block, either created by us, or from another peer.
We can validate it and if it's a good block, propagate it to other peers and
@@ -528,28 +684,40 @@ class FullNode:
if not self.blockchain.is_child_of_head(unfinished_block.block):
return
if not await self.blockchain.validate_unfinished_block(unfinished_block.block):
if not await self.blockchain.validate_unfinished_block(
unfinished_block.block
):
raise InvalidUnfinishedBlock()
prev_block: Optional[HeaderBlock] = await self.blockchain.get_header_block(
unfinished_block.block.prev_header_hash)
unfinished_block.block.prev_header_hash
)
assert prev_block
assert prev_block.challenge
challenge_hash: bytes32 = prev_block.challenge.get_hash()
difficulty: uint64 = await self.blockchain.get_next_difficulty(
unfinished_block.block.header_block.prev_header_hash)
unfinished_block.block.header_block.prev_header_hash
)
vdf_ips: uint64 = await self.blockchain.get_next_ips(
unfinished_block.block.header_block.prev_header_hash)
unfinished_block.block.header_block.prev_header_hash
)
iterations_needed: uint64 = calculate_iterations(unfinished_block.block.header_block.proof_of_space,
difficulty, vdf_ips,
constants["MIN_BLOCK_TIME"])
iterations_needed: uint64 = calculate_iterations(
unfinished_block.block.header_block.proof_of_space,
difficulty,
vdf_ips,
constants["MIN_BLOCK_TIME"],
)
if await self.store.get_unfinished_block((challenge_hash, iterations_needed)):
if await self.store.get_unfinished_block(
(challenge_hash, iterations_needed)
):
return
expected_time: uint64 = uint64(int(iterations_needed / (await self.store.get_proof_of_time_estimate_ips())))
expected_time: uint64 = uint64(
int(iterations_needed / (await self.store.get_proof_of_time_estimate_ips()))
)
if expected_time > constants["PROPAGATION_DELAY_THRESHOLD"]:
log.info(f"Block is slow, expected {expected_time} seconds, waiting")
@@ -557,35 +725,60 @@ class FullNode:
await asyncio.sleep(3)
async with self.store.lock:
leader: Tuple[uint32, uint64] = await self.store.get_unfinished_block_leader()
leader: Tuple[
uint32, uint64
] = await self.store.get_unfinished_block_leader()
if leader is None or unfinished_block.block.height > leader[0]:
log.info(f"This is the first block at height {unfinished_block.block.height}, so propagate.")
log.info(
f"This is the first block at height {unfinished_block.block.height}, so propagate."
)
# If this is the first block we see at this height, propagate
await self.store.set_unfinished_block_leader((unfinished_block.block.height, expected_time))
await self.store.set_unfinished_block_leader(
(unfinished_block.block.height, expected_time)
)
elif unfinished_block.block.height == leader[0]:
if expected_time > leader[1] + constants["PROPAGATION_THRESHOLD"]:
# If VDF is expected to finish X seconds later than the best, don't propagate
log.info(f"VDF will finish too late {expected_time} seconds, so don't propagate")
log.info(
f"VDF will finish too late {expected_time} seconds, so don't propagate"
)
return
elif expected_time < leader[1]:
log.info(f"New best unfinished block at height {unfinished_block.block.height}")
log.info(
f"New best unfinished block at height {unfinished_block.block.height}"
)
# If this will be the first block to finalize, update our leader
await self.store.set_unfinished_block_leader((leader[0], expected_time))
await self.store.set_unfinished_block_leader(
(leader[0], expected_time)
)
else:
# If we have seen an unfinished block at a greater or equal height, don't propagate
log.info(f"Unfinished block at old height, so don't propagate")
return
await self.store.add_unfinished_block((challenge_hash, iterations_needed), unfinished_block.block)
await self.store.add_unfinished_block(
(challenge_hash, iterations_needed), unfinished_block.block
)
timelord_request = timelord_protocol.ProofOfSpaceInfo(challenge_hash, iterations_needed)
timelord_request = timelord_protocol.ProofOfSpaceInfo(
challenge_hash, iterations_needed
)
yield OutboundMessage(NodeType.TIMELORD, Message("proof_of_space_info", timelord_request), Delivery.BROADCAST)
yield OutboundMessage(NodeType.FULL_NODE, Message("unfinished_block", unfinished_block),
Delivery.BROADCAST_TO_OTHERS)
yield OutboundMessage(
NodeType.TIMELORD,
Message("proof_of_space_info", timelord_request),
Delivery.BROADCAST,
)
yield OutboundMessage(
NodeType.FULL_NODE,
Message("unfinished_block", unfinished_block),
Delivery.BROADCAST_TO_OTHERS,
)
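
The propagation policy above reduces to a small decision rule: always propagate the first unfinished block seen at a new height, drop one whose VDF would finish more than PROPAGATION_THRESHOLD seconds behind the current leader, and adopt a faster one as the new leader. A hedged sketch of that rule as a pure function (the function and the threshold value here are illustrative, not part of the codebase; leader is the (height, expected_time) tuple stored above):

# Illustrative only. leader is (height, expected_time) or None, as stored above.
PROPAGATION_THRESHOLD = 30  # assumed example value, not the real constant

def should_propagate(leader, height, expected_time):
    if leader is None or height > leader[0]:
        return True, (height, expected_time)  # first block seen at this height
    if height == leader[0]:
        if expected_time > leader[1] + PROPAGATION_THRESHOLD:
            return False, leader  # VDF would finish too late; drop it
        if expected_time < leader[1]:
            return True, (leader[0], expected_time)  # new best at this height
        return True, leader  # close enough to the leader; still propagate
    return False, leader  # older height; don't propagate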
@api_request
async def block(self, block: peer_protocol.Block) -> AsyncGenerator[OutboundMessage, None]:
async def block(
self, block: peer_protocol.Block
) -> AsyncGenerator[OutboundMessage, None]:
"""
Receive a full block from a peer full node (or ourselves).
"""
@@ -598,27 +791,40 @@ class FullNode:
await self.store.add_potential_head(header_hash, block.block)
return
# Record our minimum height, and whether we have a full set of heads
least_height: uint32 = min([h.height for h in self.blockchain.get_current_tips()])
full_heads: bool = len(self.blockchain.get_current_tips()) == constants["NUMBER_OF_HEADS"]
least_height: uint32 = min(
[h.height for h in self.blockchain.get_current_tips()]
)
full_heads: bool = len(self.blockchain.get_current_tips()) == constants[
"NUMBER_OF_HEADS"
]
# Tries to add the block to the blockchain
added: ReceiveBlockResult = await self.blockchain.receive_block(block.block)
if added == ReceiveBlockResult.ALREADY_HAVE_BLOCK:
return
elif added == ReceiveBlockResult.INVALID_BLOCK:
log.warning(f"Block {header_hash} at height {block.block.height} is invalid.")
log.warning(
f"Block {header_hash} at height {block.block.height} is invalid."
)
return
elif added == ReceiveBlockResult.DISCONNECTED_BLOCK:
log.warning(f"Disconnected block {header_hash}")
async with self.store.lock:
tip_height = max([head.height for head in self.blockchain.get_current_tips()])
tip_height = max(
[head.height for head in self.blockchain.get_current_tips()]
)
if block.block.height > tip_height + self.config["sync_blocks_behind_threshold"]:
if (
block.block.height
> tip_height + self.config["sync_blocks_behind_threshold"]
):
async with self.store.lock:
await self.store.clear_sync_info()
await self.store.add_potential_head(header_hash, block.block)
log.info(f"We are too far behind this block. Our height is {tip_height} and block is at "
f"{block.block.height}")
log.info(
f"We are too far behind this block. Our height is {tip_height} and block is at "
f"{block.block.height}"
)
# Perform a sync if we have to
await self.store.set_sync_mode(True)
try:
@@ -636,8 +842,10 @@ class FullNode:
return
elif block.block.height > tip_height + 1:
log.info(f"We are a few blocks behind, our height is {tip_height} and block is at "
f"{block.block.height} so we will request these blocks.")
log.info(
f"We are a few blocks behind, our height is {tip_height} and block is at "
f"{block.block.height} so we will request these blocks."
)
while True:
# TODO: download a few blocks and add them to chain
# prev_block_hash = block.block.header_block.header.data.prev_header_hash
@@ -650,10 +858,16 @@ class FullNode:
deep_reorg: bool = (block.block.height < least_height) and full_heads
ips_changed: bool = False
async with self.store.lock:
log.info(f"Updated heads, new heights: {[b.height for b in self.blockchain.get_current_tips()]}")
log.info(
f"Updated heads, new heights: {[b.height for b in self.blockchain.get_current_tips()]}"
)
difficulty = await self.blockchain.get_next_difficulty(block.block.prev_header_hash)
next_vdf_ips = await self.blockchain.get_next_ips(block.block.header_hash)
difficulty = await self.blockchain.get_next_difficulty(
block.block.prev_header_hash
)
next_vdf_ips = await self.blockchain.get_next_ips(
block.block.header_hash
)
log.info(f"Difficulty {difficulty} IPS {next_vdf_ips}")
if next_vdf_ips != await self.store.get_proof_of_time_estimate_ips():
await self.store.set_proof_of_time_estimate_ips(next_vdf_ips)
@@ -661,53 +875,79 @@ class FullNode:
if ips_changed:
rate_update = farmer_protocol.ProofOfTimeRate(next_vdf_ips)
log.error(f"Sending proof of time rate {next_vdf_ips}")
yield OutboundMessage(NodeType.FARMER, Message("proof_of_time_rate", rate_update),
Delivery.BROADCAST)
yield OutboundMessage(
NodeType.FARMER,
Message("proof_of_time_rate", rate_update),
Delivery.BROADCAST,
)
if deep_reorg:
reorg_msg = farmer_protocol.DeepReorgNotification()
yield OutboundMessage(NodeType.FARMER, Message("deep_reorg", reorg_msg), Delivery.BROADCAST)
yield OutboundMessage(
NodeType.FARMER,
Message("deep_reorg", reorg_msg),
Delivery.BROADCAST,
)
assert block.block.header_block.proof_of_time
assert block.block.header_block.challenge
pos_quality = block.block.header_block.proof_of_space.verify_and_get_quality()
pos_quality = (
block.block.header_block.proof_of_space.verify_and_get_quality()
)
farmer_request = farmer_protocol.ProofOfSpaceFinalized(block.block.header_block.challenge.get_hash(),
block.block.header_block.challenge.height,
pos_quality,
difficulty)
timelord_request = timelord_protocol.ChallengeStart(block.block.header_block.challenge.get_hash(),
block.block.header_block.challenge.total_weight)
farmer_request = farmer_protocol.ProofOfSpaceFinalized(
block.block.header_block.challenge.get_hash(),
block.block.header_block.challenge.height,
pos_quality,
difficulty,
)
timelord_request = timelord_protocol.ChallengeStart(
block.block.header_block.challenge.get_hash(),
block.block.header_block.challenge.total_weight,
)
# Tell timelord to stop previous challenge and start with new one
yield OutboundMessage(NodeType.TIMELORD, Message("challenge_start", timelord_request), Delivery.BROADCAST)
yield OutboundMessage(
NodeType.TIMELORD,
Message("challenge_start", timelord_request),
Delivery.BROADCAST,
)
# Tell full nodes about the new block
yield OutboundMessage(NodeType.FULL_NODE, Message("block", block), Delivery.BROADCAST_TO_OTHERS)
yield OutboundMessage(
NodeType.FULL_NODE,
Message("block", block),
Delivery.BROADCAST_TO_OTHERS,
)
# Tell farmer about the new block
yield OutboundMessage(NodeType.FARMER, Message("proof_of_space_finalized", farmer_request),
Delivery.BROADCAST)
yield OutboundMessage(
NodeType.FARMER,
Message("proof_of_space_finalized", farmer_request),
Delivery.BROADCAST,
)
elif added == ReceiveBlockResult.ADDED_AS_ORPHAN:
assert block.block.header_block.proof_of_time
assert block.block.header_block.challenge
log.info(f"Received orphan block of height {block.block.header_block.challenge.height}")
log.info(
f"Received orphan block of height {block.block.header_block.challenge.height}"
)
else:
# Should never reach here, all the cases are covered
assert False
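
The disconnected-block branch above is, at heart, a height comparison against our current tip. A hedged sketch of that decision (the helper function is illustrative; sync_blocks_behind_threshold comes from the node config, and the default value here is an assumption):

# Illustrative decision helper; the threshold is read from config in the code above.
def disconnected_block_action(block_height, tip_height, sync_blocks_behind_threshold=20):
    if block_height > tip_height + sync_blocks_behind_threshold:
        return "sync"  # far ahead of us: clear sync info and enter sync mode
    if block_height > tip_height + 1:
        return "request_blocks"  # slightly ahead: fetch the missing parents
    return "ignore"  # at or below our tip: nothing more to do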
@api_request
async def peers(self, request: peer_protocol.Peers) -> \
AsyncGenerator[OutboundMessage, None]:
async def peers(
self, request: peer_protocol.Peers
) -> AsyncGenerator[OutboundMessage, None]:
conns = self.server.global_connections
log.info(f"Received peer list: {request.peer_list}")
for peer in request.peer_list:
conns.peers.add(peer)
# Pseudo-message to close the connection
yield OutboundMessage(NodeType.INTRODUCER, Message('', None),
Delivery.CLOSE)
yield OutboundMessage(NodeType.INTRODUCER, Message("", None), Delivery.CLOSE)
unconnected = conns.get_unconnected_peers()
to_connect = unconnected[:self._num_needed_peers()]
to_connect = unconnected[: self._num_needed_peers()]
if not len(to_connect):
return

View File

@@ -9,8 +9,7 @@ from yaml import safe_load
from chiapos import DiskProver
from definitions import ROOT_DIR
from src.protocols import harvester_protocol
from src.server.outbound_message import (Delivery, Message, NodeType,
OutboundMessage)
from src.server.outbound_message import Delivery, Message, NodeType, OutboundMessage
from src.types.proof_of_space import ProofOfSpace
from src.types.sized_bytes import bytes32
from src.util.api_decorators import api_request
@@ -26,9 +25,13 @@ class Harvester:
key_config_filename = os.path.join(ROOT_DIR, "src", "config", "keys.yaml")
if not os.path.isfile(key_config_filename):
raise RuntimeError("Keys not generated. Run ./src/scripts/regenerate_keys.py.")
raise RuntimeError(
"Keys not generated. Run ./src/scripts/regenerate_keys.py."
)
if not os.path.isfile(plot_config_filename):
raise RuntimeError("Plots not generated. Run ./src/scripts/create_plots.py.")
raise RuntimeError(
"Plots not generated. Run ./src/scripts/create_plots.py."
)
self.config = safe_load(open(config_filename, "r"))["harvester"]
self.key_config = safe_load(open(key_config_filename, "r"))
@@ -41,18 +44,20 @@ class Harvester:
self.challenge_hashes: Dict[bytes32, Tuple[bytes32, str, uint8]] = {}
@api_request
async def harvester_handshake(self, harvester_handshake: harvester_protocol.HarvesterHandshake):
async def harvester_handshake(
self, harvester_handshake: harvester_protocol.HarvesterHandshake
):
"""
Handshake between the harvester and farmer. The harvester receives the pool public keys,
which must be put into the plots, before the plotting process begins. We cannot
use any plots which don't have one of the pool keys.
"""
for partial_filename, plot_config in self.plot_config['plots'].items():
for partial_filename, plot_config in self.plot_config["plots"].items():
if "plot_root" in self.config:
filename = os.path.join(self.config["plot_root"], partial_filename)
else:
filename = os.path.join(ROOT_DIR, "plots", partial_filename)
pool_pubkey = PublicKey.from_bytes(bytes.fromhex(plot_config['pool_pk']))
pool_pubkey = PublicKey.from_bytes(bytes.fromhex(plot_config["pool_pk"]))
# Only use plots that have the correct pools associated with them
if pool_pubkey in harvester_handshake.pool_pubkeys:
@@ -62,7 +67,9 @@ class Harvester:
log.warn(f"Plot at {filename} does not exist.")
else:
log.warning(f"Plot {filename} has a pool key that is not in the farmer's pool_pk list.")
log.warning(
f"Plot {filename} has a pool key that is not in the farmer's pool_pk list."
)
@api_request
async def new_challenge(self, new_challenge: harvester_protocol.NewChallenge):
@@ -77,25 +84,39 @@ class Harvester:
all_responses = []
for filename, prover in self.provers.items():
try:
quality_strings = prover.get_qualities_for_challenge(new_challenge.challenge_hash)
quality_strings = prover.get_qualities_for_challenge(
new_challenge.challenge_hash
)
except RuntimeError:
log.warning("Error using prover object. Reinitializing prover object.")
self.provers[filename] = DiskProver(filename)
quality_strings = prover.get_qualities_for_challenge(new_challenge.challenge_hash)
quality_strings = prover.get_qualities_for_challenge(
new_challenge.challenge_hash
)
for index, quality_str in enumerate(quality_strings):
quality = ProofOfSpace.quality_str_to_quality(new_challenge.challenge_hash, quality_str)
self.challenge_hashes[quality] = (new_challenge.challenge_hash, filename, uint8(index))
response: harvester_protocol.ChallengeResponse = harvester_protocol.ChallengeResponse(
new_challenge.challenge_hash,
quality,
prover.get_size()
)
quality = ProofOfSpace.quality_str_to_quality(
new_challenge.challenge_hash, quality_str
)
self.challenge_hashes[quality] = (
new_challenge.challenge_hash,
filename,
uint8(index),
)
response: harvester_protocol.ChallengeResponse = harvester_protocol.ChallengeResponse(
new_challenge.challenge_hash, quality, prover.get_size()
)
all_responses.append(response)
for response in all_responses:
yield OutboundMessage(NodeType.FARMER, Message("challenge_response", response), Delivery.RESPOND)
yield OutboundMessage(
NodeType.FARMER,
Message("challenge_response", response),
Delivery.RESPOND,
)
@api_request
async def request_proof_of_space(self, request: harvester_protocol.RequestProofOfSpace):
async def request_proof_of_space(
self, request: harvester_protocol.RequestProofOfSpace
):
"""
The farmer requests a proof of space, for one of the proofs that we found.
We look up the correct plot based on the quality, look up the proof, and return it.
@@ -115,24 +136,34 @@ class Harvester:
self.provers[filename] = DiskProver(filename)
proof_xs = self.provers[filename].get_full_proof(challenge_hash, index)
pool_pubkey = PublicKey.from_bytes(bytes.fromhex(self.plot_config['plots'][filename]['pool_pk']))
plot_pubkey = PrivateKey.from_bytes(bytes.fromhex(self.plot_config['plots'][filename]['sk'])) \
.get_public_key()
proof_of_space: ProofOfSpace = ProofOfSpace(challenge_hash,
pool_pubkey,
plot_pubkey,
uint8(self.provers[filename].get_size()),
[uint8(b) for b in proof_xs])
pool_pubkey = PublicKey.from_bytes(
bytes.fromhex(self.plot_config["plots"][filename]["pool_pk"])
)
plot_pubkey = PrivateKey.from_bytes(
bytes.fromhex(self.plot_config["plots"][filename]["sk"])
).get_public_key()
proof_of_space: ProofOfSpace = ProofOfSpace(
challenge_hash,
pool_pubkey,
plot_pubkey,
uint8(self.provers[filename].get_size()),
[uint8(b) for b in proof_xs],
)
response = harvester_protocol.RespondProofOfSpace(
request.quality,
proof_of_space
request.quality, proof_of_space
)
if response:
yield OutboundMessage(NodeType.FARMER, Message("respond_proof_of_space", response), Delivery.RESPOND)
yield OutboundMessage(
NodeType.FARMER,
Message("respond_proof_of_space", response),
Delivery.RESPOND,
)
@api_request
async def request_header_signature(self, request: harvester_protocol.RequestHeaderSignature):
async def request_header_signature(
self, request: harvester_protocol.RequestHeaderSignature
):
"""
The farmer requests a signature on the header hash, for one of the proofs that we found.
A signature is created on the header hash using the plot private key.
@@ -140,29 +171,47 @@ class Harvester:
_, filename, _ = self.challenge_hashes[request.quality]
plot_sk = PrivateKey.from_bytes(bytes.fromhex(self.plot_config['plots'][filename]['sk']))
header_hash_signature: PrependSignature = plot_sk.sign_prepend(request.header_hash)
assert(header_hash_signature.verify([Util.hash256(request.header_hash)], [plot_sk.get_public_key()]))
plot_sk = PrivateKey.from_bytes(
bytes.fromhex(self.plot_config["plots"][filename]["sk"])
)
header_hash_signature: PrependSignature = plot_sk.sign_prepend(
request.header_hash
)
assert header_hash_signature.verify(
[Util.hash256(request.header_hash)], [plot_sk.get_public_key()]
)
response: harvester_protocol.RespondHeaderSignature = harvester_protocol.RespondHeaderSignature(
request.quality,
header_hash_signature,
request.quality, header_hash_signature,
)
yield OutboundMessage(
NodeType.FARMER,
Message("respond_header_signature", response),
Delivery.RESPOND,
)
yield OutboundMessage(NodeType.FARMER, Message("respond_header_signature", response), Delivery.RESPOND)
@api_request
async def request_partial_proof(self, request: harvester_protocol.RequestPartialProof):
async def request_partial_proof(
self, request: harvester_protocol.RequestPartialProof
):
"""
The farmer requests a signature on the farmer_target, for one of the proofs that we found.
We look up the correct plot based on the quality, look up the proof, and sign
the farmer target hash using the plot private key. This will be used as a pool share.
"""
_, filename, _ = self.challenge_hashes[request.quality]
plot_sk = PrivateKey.from_bytes(bytes.fromhex(self.plot_config['plots'][filename]['sk']))
farmer_target_signature: PrependSignature = plot_sk.sign_prepend(request.farmer_target_hash)
plot_sk = PrivateKey.from_bytes(
bytes.fromhex(self.plot_config["plots"][filename]["sk"])
)
farmer_target_signature: PrependSignature = plot_sk.sign_prepend(
request.farmer_target_hash
)
response: harvester_protocol.RespondPartialProof = harvester_protocol.RespondPartialProof(
request.quality,
farmer_target_signature
request.quality, farmer_target_signature
)
yield OutboundMessage(
NodeType.FARMER,
Message("respond_partial_proof", response),
Delivery.RESPOND,
)
yield OutboundMessage(NodeType.FARMER, Message("respond_partial_proof", response), Delivery.RESPOND)
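
All three request handlers above share one signing pattern: load the plot's private key from hex, sign the incoming hash with sign_prepend, and verify against the plot public key. A standalone round trip using the same blspy calls as this file (the seed and message are arbitrary example data):

# Standalone sign/verify round trip with the blspy API used above.
from blspy import PrivateKey, Util

seed = bytes([1] * 32)  # arbitrary example seed
plot_sk = PrivateKey.from_seed(seed)
message = b"example header hash"
signature = plot_sk.sign_prepend(message)
assert signature.verify([Util.hash256(message)], [plot_sk.get_public_key()])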

View File

@@ -5,8 +5,7 @@ import yaml
from definitions import ROOT_DIR
from src.protocols.peer_protocol import Peers, RequestPeers
from src.server.outbound_message import (Delivery, Message, NodeType,
OutboundMessage)
from src.server.outbound_message import Delivery, Message, NodeType, OutboundMessage
from src.server.server import ChiaServer
from src.util.api_decorators import api_request
@@ -20,9 +19,10 @@ class Introducer:
self.server = server
@api_request
async def request_peers(self, request: RequestPeers) \
-> AsyncGenerator[OutboundMessage, None]:
max_peers = self.config['max_peers_to_send']
async def request_peers(
self, request: RequestPeers
) -> AsyncGenerator[OutboundMessage, None]:
max_peers = self.config["max_peers_to_send"]
peers = self.server.global_connections.peers.get_peers(max_peers, True)
msg = Message("peers", Peers(peers))
yield OutboundMessage(NodeType.FULL_NODE, msg, Delivery.RESPOND)

View File

@@ -22,6 +22,7 @@ class TransactionId:
"""
Receive a transaction id from a peer.
"""
transaction_id: bytes32
@@ -31,6 +32,7 @@ class RequestTransaction:
"""
Request a transaction from a peer.
"""
transaction_id: bytes32
@@ -40,6 +42,7 @@ class NewTransaction:
"""
Receive a transaction from a peer.
"""
transaction: Transaction
@@ -49,6 +52,7 @@ class NewProofOfTime:
"""
Receive a new proof of time from a peer.
"""
proof: ProofOfTime
@@ -58,6 +62,7 @@ class UnfinishedBlock:
"""
Receive an unfinished block from a peer.
"""
# Block that does not have ProofOfTime and Challenge
block: FullBlock
@@ -68,6 +73,7 @@ class RequestBlock:
"""
Requests a block from a peer.
"""
header_hash: bytes32
@@ -77,6 +83,7 @@ class Block:
"""
Receive a block from a peer.
"""
block: FullBlock
@@ -94,6 +101,7 @@ class Peers:
"""
Update list of peers
"""
peer_list: List[PeerInfo]
@@ -103,6 +111,7 @@ class RequestHeaderBlocks:
"""
Request headers of blocks that are ancestors of the specified tip.
"""
tip_header_hash: bytes32
heights: List[uint64]
@@ -113,6 +122,7 @@ class HeaderBlocks:
"""
Sends header blocks that are ancestors of the specified tip, at the specified heights.
"""
tip_header_hash: bytes32
header_blocks: List[HeaderBlock]
@@ -123,6 +133,7 @@ class RequestSyncBlocks:
"""
Request download of blocks, in the blockchain that has 'tip_header_hash' as the tip
"""
tip_header_hash: bytes32
heights: List[uint64]
@@ -133,5 +144,6 @@ class SyncBlocks:
"""
Send blocks to peer.
"""
tip_header_hash: bytes32
blocks: List[FullBlock]
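
Each message above is a frozen dataclass wrapped in the repo's cbor_message decorator, so its fields serialize directly as a CBOR map. As a hedged illustration of the pattern, one more such message would look like the sketch below (RequestPing is hypothetical and exists purely to show the shape; cbor_message, bytes32, and dataclass are the imports this file already uses):

# Hypothetical example message, following the same pattern as the classes above.
@dataclass(frozen=True)
@cbor_message
class RequestPing:
    """
    Ask a peer to confirm liveness. (Illustrative only; not part of the protocol.)
    """

    nonce: bytes32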

View File

@@ -10,6 +10,8 @@ protocol_version = "0.0.2"
"""
Handshake when establishing a connection between two servers.
"""
@dataclass(frozen=True)
@cbor_message
class Handshake:

View File

@@ -15,6 +15,8 @@ If don't have the unfinished block, ignore
Validate PoT
Call self.Block
"""
@dataclass(frozen=True)
@cbor_message
class ProofOfTimeFinished:

View File

@@ -20,12 +20,14 @@ def main():
Script for creating plots and adding them to the plot config file.
"""
parser = argparse.ArgumentParser(
description="Chia plotting script."
)
parser = argparse.ArgumentParser(description="Chia plotting script.")
parser.add_argument("-k", "--size", help="Plot size", type=int, default=20)
parser.add_argument("-n", "--num_plots", help="Number of plots", type=int, default=10)
parser.add_argument("-p", "--pool_pub_key", help="Hex public key of pool", type=str, default="")
parser.add_argument(
"-n", "--num_plots", help="Number of plots", type=int, default=10
)
parser.add_argument(
"-p", "--pool_pub_key", help="Hex public key of pool", type=str, default=""
)
# We need the keys file, to access pool keys (if they exist), and the sk_seed.
args = parser.parse_args()
@@ -45,14 +47,20 @@ def main():
pool_sk = PrivateKey.from_bytes(bytes.fromhex(key_config["pool_sks"][0]))
pool_pk = pool_sk.get_public_key()
print(f"Creating {args.num_plots} plots of size {args.size}, sk_seed {sk_seed.hex()} ppk {pool_pk}")
print(
f"Creating {args.num_plots} plots of size {args.size}, sk_seed {sk_seed.hex()} ppk {pool_pk}"
)
for i in range(args.num_plots):
# Generate a sk based on the seed, plot size (k), and index
sk: PrivateKey = PrivateKey.from_seed(sk_seed + args.size.to_bytes(1, "big") + i.to_bytes(4, "big"))
sk: PrivateKey = PrivateKey.from_seed(
sk_seed + args.size.to_bytes(1, "big") + i.to_bytes(4, "big")
)
# The plot seed is based on the pool and plot pks
plot_seed: bytes32 = ProofOfSpace.calculate_plot_seed(pool_pk, sk.get_public_key())
plot_seed: bytes32 = ProofOfSpace.calculate_plot_seed(
pool_pk, sk.get_public_key()
)
filename: str = f"plot-{i}-{args.size}-{plot_seed}.dat"
full_path: str = os.path.join(plot_root, filename)
if os.path.isfile(full_path):
@@ -71,7 +79,7 @@ def main():
if filename not in plot_config_plots_new:
plot_config_plots_new[filename] = {
"sk": bytes(sk).hex(),
"pool_pk": bytes(pool_pk).hex()
"pool_pk": bytes(pool_pk).hex(),
}
plot_config["plots"].update(plot_config_plots_new)

View File

@@ -14,12 +14,12 @@ key_config_filename = os.path.join(ROOT_DIR, "src", "config", "keys.yaml")
def str2bool(v: str) -> bool:
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
raise argparse.ArgumentTypeError("Boolean value expected.")
def main():
@@ -27,21 +27,42 @@ def main():
Allows replacing keys of farmer, harvester, and pool; all default to True.
"""
parser = argparse.ArgumentParser(
description="Chia key generator script."
parser = argparse.ArgumentParser(description="Chia key generator script.")
parser.add_argument(
"-f",
"--farmer",
type=str2bool,
nargs="?",
const=True,
default=True,
help="Regenerate farmer key",
)
parser.add_argument(
"-a",
"--harvester",
type=str2bool,
nargs="?",
const=True,
default=True,
help="Regenerate plot key seed",
)
parser.add_argument(
"-p",
"--pool",
type=str2bool,
nargs="?",
const=True,
default=True,
help="Regenerate pool keys",
)
parser.add_argument("-f", "--farmer", type=str2bool, nargs='?', const=True, default=True,
help="Regenerate farmer key")
parser.add_argument("-a", "--harvester", type=str2bool, nargs='?', const=True, default=True,
help="Regenerate plot key seed")
parser.add_argument("-p", "--pool", type=str2bool, nargs='?', const=True, default=True,
help="Regenerate pool keys")
args = parser.parse_args()
if os.path.isfile(key_config_filename):
# If the file exists, warn the user
yn = input(f"The keys file {key_config_filename} already exists. Are you sure"
f" you want to override the keys? Plots might become invalid. (y/n): ")
yn = input(
f"The keys file {key_config_filename} already exists. Are you sure"
f" you want to override the keys? Plots might become invalid. (y/n): "
)
if not (yn.lower() == "y" or yn.lower() == "yes"):
quit()
else:

View File

@@ -20,8 +20,15 @@ class Connection:
port are the host and port of the peer that we are connected to. Node_id and connection_type are
set after the handshake is performed in this connection.
"""
def __init__(self, local_type: NodeType, connection_type: Optional[NodeType], sr: StreamReader,
sw: StreamWriter, server_port: int):
def __init__(
self,
local_type: NodeType,
connection_type: Optional[NodeType],
sr: StreamReader,
sw: StreamWriter,
server_port: int,
):
self.local_type = local_type
self.connection_type = connection_type
self.reader = sr
@@ -47,24 +54,23 @@ class Connection:
return self.writer.get_extra_info("socket")
def get_peer_info(self) -> Optional[PeerInfo]:
if not self.peer_server_port or \
self.connection_type != NodeType.FULL_NODE:
if not self.peer_server_port or self.connection_type != NodeType.FULL_NODE:
return None
return PeerInfo(self.peer_host, uint16(self.peer_server_port))
async def send(self, message: Message):
encoded: bytes = cbor.dumps({"f": message.function, "d": message.data})
assert(len(encoded) < (2**(LENGTH_BYTES*8)))
assert len(encoded) < (2 ** (LENGTH_BYTES * 8))
self.writer.write(len(encoded).to_bytes(LENGTH_BYTES, "big") + encoded)
await self.writer.drain()
self.bytes_written += (LENGTH_BYTES + len(encoded))
self.bytes_written += LENGTH_BYTES + len(encoded)
async def read_one_message(self) -> Message:
size = await self.reader.readexactly(LENGTH_BYTES)
full_message_length = int.from_bytes(size, "big")
full_message: bytes = await self.reader.readexactly(full_message_length)
full_message_loaded: Any = cbor.loads(full_message)
self.bytes_read += (LENGTH_BYTES + full_message_length)
self.bytes_read += LENGTH_BYTES + full_message_length
return Message(full_message_loaded["f"], full_message_loaded["d"])
def close(self):
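
The wire format that send and read_one_message implement is a fixed-width big-endian length prefix followed by a CBOR map with keys "f" (function name) and "d" (data). A standalone sketch of that framing, assuming LENGTH_BYTES is 4 and using the cbor2 package as a stand-in for the cbor module imported here:

# Framing sketch matching send()/read_one_message(): length prefix + CBOR body.
import cbor2  # assumed stand-in for the cbor module used above

LENGTH_BYTES = 4  # assumed value of the module-level constant

def encode_message(function, data):
    body = cbor2.dumps({"f": function, "d": data})
    return len(body).to_bytes(LENGTH_BYTES, "big") + body

def decode_message(raw):
    length = int.from_bytes(raw[:LENGTH_BYTES], "big")
    body = cbor2.loads(raw[LENGTH_BYTES:LENGTH_BYTES + length])
    return body["f"], body["d"]

assert decode_message(encode_message("ping", {"nonce": 1})) == ("ping", {"nonce": 1})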
@@ -117,8 +123,7 @@ class PeerConnections:
return list(filter(Connection.get_peer_info, self._all_connections))
def get_full_node_peerinfos(self):
return list(filter(None, map(Connection.get_peer_info,
self._all_connections)))
return list(filter(None, map(Connection.get_peer_info, self._all_connections)))
def get_unconnected_peers(self, max_peers=0):
connected = self.get_full_node_peerinfos()
@@ -134,6 +139,7 @@ class Peers:
Has the list of known full node peers that we are already connected to or may
connect to.
"""
def __init__(self):
self._peers: List[PeerInfo] = []
@@ -152,8 +158,7 @@ class Peers:
except ValueError:
return False
def get_peers(self, max_peers: int = 0, randomize: bool = False) \
-> List[PeerInfo]:
def get_peers(self, max_peers: int = 0, randomize: bool = False) -> List[PeerInfo]:
if not max_peers or max_peers > len(self._peers):
max_peers = len(self._peers)
if randomize:

View File

@@ -3,18 +3,19 @@ import logging
import random
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Tuple
from aiter import (aiter_forker, iter_to_aiter, join_aiters, map_aiter,
push_aiter)
from aiter import aiter_forker, iter_to_aiter, join_aiters, map_aiter, push_aiter
from aiter.server import start_server_aiter
from src.protocols.shared_protocol import (Handshake, HandshakeAck,
protocol_version)
from src.protocols.shared_protocol import Handshake, HandshakeAck, protocol_version
from src.server.connection import Connection, PeerConnections
from src.server.outbound_message import (Delivery, Message, NodeType,
OutboundMessage)
from src.server.outbound_message import Delivery, Message, NodeType, OutboundMessage
from src.types.peer_info import PeerInfo
from src.util import partial_func
from src.util.errors import (IncompatibleProtocolVersion, InvalidAck,
InvalidHandshake, InvalidProtocolMessage)
from src.util.errors import (
IncompatibleProtocolVersion,
InvalidAck,
InvalidHandshake,
InvalidProtocolMessage,
)
from src.util.ints import uint16
from src.util.network import create_node_id
@@ -54,11 +55,18 @@ class ChiaServer:
self._local_type = local_type # NodeType (farmer, full node, timelord, pool, harvester, wallet)
self._srwt_aiter = push_aiter()
self._outbound_aiter = push_aiter()
self._pipeline_task = self.initialize_pipeline(self._srwt_aiter, self._api, self._port)
self._pipeline_task = self.initialize_pipeline(
self._srwt_aiter, self._api, self._port
)
self._node_id = create_node_id()
async def start_server(self, host: str,
on_connect: Optional[Callable[[], AsyncGenerator[OutboundMessage, None]]] = None) -> bool:
async def start_server(
self,
host: str,
on_connect: Optional[
Callable[[], AsyncGenerator[OutboundMessage, None]]
] = None,
) -> bool:
"""
Launches a listening server on host and port specified, to connect to NodeType nodes. On each
connection, the on_connect asynchronous generator will be called, and responses will be sent.
@@ -68,13 +76,17 @@ class ChiaServer:
return False
self._host = host
self._server, aiter = await start_server_aiter(self._port, host=None, reuse_address=True)
self._server, aiter = await start_server_aiter(
self._port, host=None, reuse_address=True
)
if on_connect is not None:
self._on_connect_generic_callback = on_connect
def add_connection_type(srw: Tuple[asyncio.StreamReader, asyncio.StreamWriter]) -> \
Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
def add_connection_type(
srw: Tuple[asyncio.StreamReader, asyncio.StreamWriter]
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
return (srw[0], srw[1])
srwt_aiter = map_aiter(add_connection_type, aiter)
# Push all aiters that come from the server, into the pipeline
@@ -83,8 +95,13 @@ class ChiaServer:
log.info(f"Server started on port {self._port}")
return True
async def start_client(self, target_node: PeerInfo,
on_connect: Optional[Callable[[], AsyncGenerator[OutboundMessage, None]]] = None) -> bool:
async def start_client(
self,
target_node: PeerInfo,
on_connect: Optional[
Callable[[], AsyncGenerator[OutboundMessage, None]]
] = None,
) -> bool:
"""
Tries to connect to the target node, adding one connection into the pipeline, if successful.
An on connect method can also be specified, and this will be saved into the instance variables.
@@ -99,7 +116,9 @@ class ChiaServer:
if self._pipeline_task.done():
return False
try:
reader, writer = await asyncio.open_connection(target_node.host, int(target_node.port))
reader, writer = await asyncio.open_connection(
target_node.host, int(target_node.port)
)
succeeded = True
break
except (ConnectionRefusedError, TimeoutError, OSError) as e:
@@ -117,7 +136,10 @@ class ChiaServer:
asyncio.create_task(self._add_to_srwt_aiter(iter_to_aiter([(reader, writer)])))
return True
async def _add_to_srwt_aiter(self, aiter: AsyncGenerator[Tuple[asyncio.StreamReader, asyncio.StreamWriter], None]):
async def _add_to_srwt_aiter(
self,
aiter: AsyncGenerator[Tuple[asyncio.StreamReader, asyncio.StreamWriter], None],
):
"""
Adds all srwt from aiter into the instance variable srwt_aiter, adding them to the pipeline.
"""
@@ -156,36 +178,53 @@ class ChiaServer:
"""
# Maps a stream reader, writer and NodeType to a Connection object
connections_aiter = map_aiter(partial_func.partial_async(self.stream_reader_writer_to_connection,
server_port), aiter)
connections_aiter = map_aiter(
partial_func.partial_async(
self.stream_reader_writer_to_connection, server_port
),
aiter,
)
# Performs a handshake with the peer
handshaked_connections_aiter = join_aiters(map_aiter(self.perform_handshake, connections_aiter))
handshaked_connections_aiter = join_aiters(
map_aiter(self.perform_handshake, connections_aiter)
)
forker = aiter_forker(handshaked_connections_aiter)
handshake_finished_1 = forker.fork(is_active=True)
handshake_finished_2 = forker.fork(is_active=True)
# Reads messages one at a time from the TCP connection
messages_aiter = join_aiters(map_aiter(self.connection_to_message, handshake_finished_1, 100))
messages_aiter = join_aiters(
map_aiter(self.connection_to_message, handshake_finished_1, 100)
)
# Handles each message one at a time, and yields responses to send back or broadcast
responses_aiter = join_aiters(map_aiter(
partial_func.partial_async_gen(self.handle_message, api),
messages_aiter, 100))
responses_aiter = join_aiters(
map_aiter(
partial_func.partial_async_gen(self.handle_message, api),
messages_aiter,
100,
)
)
# Uses a forked aiter, and calls the on_connect function to send some initial messages
# as soon as the connection is established
on_connect_outbound_aiter = join_aiters(map_aiter(self.connection_to_outbound,
handshake_finished_2, 100))
on_connect_outbound_aiter = join_aiters(
map_aiter(self.connection_to_outbound, handshake_finished_2, 100)
)
# Also uses the instance variable _outbound_aiter, which clients can use to send messages
# at any time, not just on_connect.
outbound_aiter_mapped = map_aiter(lambda x: (None, x), self._outbound_aiter)
responses_aiter = join_aiters(iter_to_aiter([responses_aiter, on_connect_outbound_aiter,
outbound_aiter_mapped]))
responses_aiter = join_aiters(
iter_to_aiter(
[responses_aiter, on_connect_outbound_aiter, outbound_aiter_mapped]
)
)
# For each outbound message, replicate for each peer that we need to send to
expanded_messages_aiter = join_aiters(map_aiter(
self.expand_outbound_messages, responses_aiter, 100))
expanded_messages_aiter = join_aiters(
map_aiter(self.expand_outbound_messages, responses_aiter, 100)
)
# This will run forever. Sends each message through the TCP connection, using the
# length encoding and CBOR serialization
@@ -195,15 +234,17 @@ class ChiaServer:
try:
await connection.send(message)
except (ConnectionResetError, BrokenPipeError) as e:
log.error(f"Cannot write to {connection}, already closed. Error {e}.")
log.error(
f"Cannot write to {connection}, already closed. Error {e}."
)
# We will return a task for this, so user of start_chia_server or start_chia_client can wait until
# the server is closed.
return asyncio.get_running_loop().create_task(serve_forever())
async def stream_reader_writer_to_connection(self,
swrt: Tuple[asyncio.StreamReader, asyncio.StreamWriter],
server_port: int) -> Connection:
async def stream_reader_writer_to_connection(
self, swrt: Tuple[asyncio.StreamReader, asyncio.StreamWriter], server_port: int
) -> Connection:
"""
Maps a pair of (StreamReader, StreamWriter) to a Connection object,
which also stores the type of connection (str). It is also added to the global list.
@@ -214,8 +255,9 @@ class ChiaServer:
log.info(f"Connection with {con.get_peername()} established")
return con
async def connection_to_outbound(self, connection: Connection) -> AsyncGenerator[
Tuple[Connection, OutboundMessage], None]:
async def connection_to_outbound(
self, connection: Connection
) -> AsyncGenerator[Tuple[Connection, OutboundMessage], None]:
"""
Async generator which calls the on_connect async generator method, and yields any outbound messages.
"""
@@ -228,14 +270,21 @@ class ChiaServer:
async for outbound_message in self._on_connect_generic_callback():
yield connection, outbound_message
async def perform_handshake(self, connection: Connection) -> AsyncGenerator[Connection, None]:
async def perform_handshake(
self, connection: Connection
) -> AsyncGenerator[Connection, None]:
"""
Performs handshake with this new connection, and yields the connection. If the handshake
is unsuccessful, or we already have a connection with this peer, the connection is closed,
and nothing is yielded.
"""
# Send handshake message
outbound_handshake = Message("handshake", Handshake(protocol_version, self._node_id, uint16(self._port), self._local_type))
outbound_handshake = Message(
"handshake",
Handshake(
protocol_version, self._node_id, uint16(self._port), self._local_type
),
)
try:
await connection.send(outbound_handshake)
@@ -243,7 +292,11 @@ class ChiaServer:
# Read handshake message
full_message = await connection.read_one_message()
inbound_handshake = Handshake(**full_message.data)
if full_message.function != "handshake" or not inbound_handshake or not inbound_handshake.node_type:
if (
full_message.function != "handshake"
or not inbound_handshake
or not inbound_handshake.node_type
):
raise InvalidHandshake("Invalid handshake")
# Makes sure that we only start one connection with each peer
@@ -265,21 +318,33 @@ class ChiaServer:
raise InvalidAck("Invalid ack")
if inbound_handshake.version != protocol_version:
raise IncompatibleProtocolVersion(f"Our node version {protocol_version} is not compatible with peer\
{connection} version {inbound_handshake.version}")
raise IncompatibleProtocolVersion(
f"Our node version {protocol_version} is not compatible with peer\
{connection} version {inbound_handshake.version}"
)
log.info((f"Handshake with {NodeType(connection.connection_type).name} {connection.get_peername()} "
f"{connection.node_id}"
f" established"))
log.info(
(
f"Handshake with {NodeType(connection.connection_type).name} {connection.get_peername()} "
f"{connection.node_id}"
f" established"
)
)
# Only yield a connection if the handshake is successful and the connection is not a duplicate.
yield connection
except (IncompatibleProtocolVersion, InvalidAck,
InvalidHandshake, asyncio.IncompleteReadError, ConnectionResetError) as e:
except (
IncompatibleProtocolVersion,
InvalidAck,
InvalidHandshake,
asyncio.IncompleteReadError,
ConnectionResetError,
) as e:
log.warning(f"{e}, handshake not completed. Connection not created.")
connection.close()
async def connection_to_message(self, connection: Connection) -> AsyncGenerator[
Tuple[Connection, Message], None]:
async def connection_to_message(
self, connection: Connection
) -> AsyncGenerator[Tuple[Connection, Message], None]:
"""
Async generator which yields complete binary messages from connections,
along with a streamwriter to send back responses. On EOF received, the connection
@@ -291,15 +356,20 @@ class ChiaServer:
# Read one message at a time, forever
yield (connection, message)
except asyncio.IncompleteReadError:
log.warning(f"Received EOF from {connection.get_peername()}, closing connection.")
log.warning(
f"Received EOF from {connection.get_peername()}, closing connection."
)
except ConnectionError:
log.warning(f"Connection error by peer {connection.get_peername()}, closing connection.")
log.warning(
f"Connection error by peer {connection.get_peername()}, closing connection."
)
finally:
# Removes the connection from the global list, so we don't try to send things to it
self.global_connections.close(connection, True)
async def handle_message(self, pair: Tuple[Connection, Message], api: Any) -> AsyncGenerator[
Tuple[Connection, OutboundMessage], None]:
async def handle_message(
self, pair: Tuple[Connection, Message], api: Any
) -> AsyncGenerator[Tuple[Connection, OutboundMessage], None]:
"""
Async generator which takes messages, parses, them, executes the right
api function, and yields responses (to same connection, propagated, etc).
@@ -325,8 +395,9 @@ class ChiaServer:
log.error(f"Error {type(e)} {e}, closing connection {connection}")
self.global_connections.close(connection)
async def expand_outbound_messages(self, pair: Tuple[Connection, OutboundMessage]) -> AsyncGenerator[
Tuple[Connection, Message], None]:
async def expand_outbound_messages(
self, pair: Tuple[Connection, OutboundMessage]
) -> AsyncGenerator[Tuple[Connection, Message], None]:
"""
Expands each of the outbound messages into its own message.
"""
@@ -339,13 +410,18 @@ class ChiaServer:
elif outbound_message.delivery_method == Delivery.RANDOM:
# Select a random peer.
to_yield_single: Tuple[Connection, Message]
typed_peers: List[Connection] = [peer for peer in self.global_connections.get_connections()
if peer.connection_type == outbound_message.peer_type]
typed_peers: List[Connection] = [
peer
for peer in self.global_connections.get_connections()
if peer.connection_type == outbound_message.peer_type
]
if len(typed_peers) == 0:
return
yield (random.choice(typed_peers), outbound_message.message)
elif (outbound_message.delivery_method == Delivery.BROADCAST or
outbound_message.delivery_method == Delivery.BROADCAST_TO_OTHERS):
elif (
outbound_message.delivery_method == Delivery.BROADCAST
or outbound_message.delivery_method == Delivery.BROADCAST_TO_OTHERS
):
# Broadcast to all peers.
for peer in self.global_connections.get_connections():
if peer.connection_type == outbound_message.peer_type:
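
The whole pipeline above is composed from a handful of aiter combinators: map_aiter transforms each item (possibly into a new async generator), join_aiters flattens a stream of async generators, and push_aiter/iter_to_aiter act as sources. A toy pipeline in the same shape, hedged to the aiter API as imported in this module:

# Toy fan-out pipeline mirroring the structure above (aiter package assumed).
import asyncio
from aiter import iter_to_aiter, join_aiters, map_aiter

async def fan_out(x):
    # each input becomes its own async generator, like handle_message above
    for i in range(x):
        yield (x, i)

async def demo():
    expanded = join_aiters(map_aiter(fan_out, iter_to_aiter([1, 2, 3])))
    async for pair in expanded:
        print(pair)  # e.g. (1, 0), (2, 0), (2, 1), (3, 0), ... (order may interleave)

asyncio.run(demo())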

View File

@@ -7,27 +7,30 @@ from blspy import PrivateKey
from src.farmer import Farmer
from src.protocols.harvester_protocol import HarvesterHandshake
from src.server.outbound_message import (Delivery, Message, NodeType,
OutboundMessage)
from src.server.outbound_message import Delivery, Message, NodeType, OutboundMessage
from src.server.server import ChiaServer
from src.types.peer_info import PeerInfo
from src.util.network import parse_host_port
logging.basicConfig(format='Farmer %(name)-25s: %(levelname)-8s %(asctime)s.%(msecs)03d %(message)s',
level=logging.INFO,
datefmt='%H:%M:%S'
)
logging.basicConfig(
format="Farmer %(name)-25s: %(levelname)-8s %(asctime)s.%(msecs)03d %(message)s",
level=logging.INFO,
datefmt="%H:%M:%S",
)
async def main():
farmer = Farmer()
harvester_peer = PeerInfo(farmer.config['harvester_peer']['host'],
farmer.config['harvester_peer']['port'])
full_node_peer = PeerInfo(farmer.config['full_node_peer']['host'],
farmer.config['full_node_peer']['port'])
harvester_peer = PeerInfo(
farmer.config["harvester_peer"]["host"], farmer.config["harvester_peer"]["port"]
)
full_node_peer = PeerInfo(
farmer.config["full_node_peer"]["host"], farmer.config["full_node_peer"]["port"]
)
def signal_received():
server.close_all()
asyncio.get_running_loop().add_signal_handler(signal.SIGINT, signal_received)
asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, signal_received)
@@ -36,11 +39,14 @@ async def main():
async def on_connect():
# Sends a handshake to the harvester
pool_sks: List[PrivateKey] = [PrivateKey.from_bytes(bytes.fromhex(ce))
for ce in farmer.key_config["pool_sks"]]
pool_sks: List[PrivateKey] = [
PrivateKey.from_bytes(bytes.fromhex(ce))
for ce in farmer.key_config["pool_sks"]
]
msg = HarvesterHandshake([sk.get_public_key() for sk in pool_sks])
yield OutboundMessage(NodeType.HARVESTER, Message("harvester_handshake", msg),
Delivery.BROADCAST)
yield OutboundMessage(
NodeType.HARVESTER, Message("harvester_handshake", msg), Delivery.BROADCAST
)
_ = await server.start_server(host, on_connect)
_ = await server.start_client(harvester_peer, None)
@@ -48,4 +54,5 @@ async def main():
await server.await_closed()
asyncio.run(main())

View File

@@ -11,10 +11,11 @@ from src.server.server import ChiaServer
from src.types.peer_info import PeerInfo
from src.util.network import parse_host_port
logging.basicConfig(format='FullNode %(name)-23s: %(levelname)-8s %(asctime)s.%(msecs)03d %(message)s',
level=logging.INFO,
datefmt='%H:%M:%S'
)
logging.basicConfig(
format="FullNode %(name)-23s: %(levelname)-8s %(asctime)s.%(msecs)03d %(message)s",
level=logging.INFO,
datefmt="%H:%M:%S",
)
log = logging.getLogger(__name__)
server_closed = False
@@ -26,7 +27,7 @@ async def main():
if "-id" in sys.argv:
db_id = int(sys.argv[sys.argv.index("-id") + 1])
store = FullNodeStore(f"fndb_{db_id}")
blockchain = Blockchain(store)
await blockchain.initialize()
@@ -50,6 +51,7 @@ async def main():
if ui_close_cb:
ui_close_cb()
master_close_cb()
asyncio.get_running_loop().add_signal_handler(signal.SIGINT, signal_received)
asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, signal_received)
@@ -58,22 +60,34 @@ async def main():
index = sys.argv.index("-u")
ui_ssh_port = int(sys.argv[index + 1])
from src.ui.prompt_ui import start_ssh_server
wait_for_ui, ui_close_cb = start_ssh_server(store, blockchain, server, port, ui_ssh_port,
full_node.config['ssh_filename'], master_close_cb)
connect_to_farmer = ("-f" in sys.argv)
connect_to_timelord = ("-t" in sys.argv)
wait_for_ui, ui_close_cb = start_ssh_server(
store,
blockchain,
server,
port,
ui_ssh_port,
full_node.config["ssh_filename"],
master_close_cb,
)
connect_to_farmer = "-f" in sys.argv
connect_to_timelord = "-t" in sys.argv
full_node._start_bg_tasks()
if connect_to_farmer and not server_closed:
peer_info = PeerInfo(full_node.config['farmer_peer']['host'],
full_node.config['farmer_peer']['port'])
peer_info = PeerInfo(
full_node.config["farmer_peer"]["host"],
full_node.config["farmer_peer"]["port"],
)
_ = await server.start_client(peer_info, None)
if connect_to_timelord and not server_closed:
peer_info = PeerInfo(full_node.config['timelord_peer']['host'],
full_node.config['timelord_peer']['port'])
peer_info = PeerInfo(
full_node.config["timelord_peer"]["host"],
full_node.config["timelord_peer"]["port"],
)
_ = await server.start_client(peer_info, None)
log.info("Waiting to connect to some peers...")
@@ -98,5 +112,6 @@ async def main():
await wait_for_ui()
await asyncio.get_running_loop().shutdown_asyncgens()
#asyncio.run(main())
# asyncio.run(main())
FullNodeStore.loop.run_until_complete(main())

View File

@@ -8,10 +8,11 @@ from src.server.server import ChiaServer
from src.types.peer_info import PeerInfo
from src.util.network import parse_host_port
logging.basicConfig(format='Harvester %(name)-24s: %(levelname)-8s %(asctime)s.%(msecs)03d %(message)s',
level=logging.INFO,
datefmt='%H:%M:%S'
)
logging.basicConfig(
format="Harvester %(name)-24s: %(levelname)-8s %(asctime)s.%(msecs)03d %(message)s",
level=logging.INFO,
datefmt="%H:%M:%S",
)
async def main():
@@ -22,13 +23,16 @@ async def main():
def signal_received():
server.close_all()
asyncio.get_running_loop().add_signal_handler(signal.SIGINT, signal_received)
asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, signal_received)
peer_info = PeerInfo(harvester.config['farmer_peer']['host'],
harvester.config['farmer_peer']['port'])
peer_info = PeerInfo(
harvester.config["farmer_peer"]["host"], harvester.config["farmer_peer"]["port"]
)
_ = await server.start_client(peer_info, None)
await server.await_closed()
asyncio.run(main())

View File

@@ -7,10 +7,11 @@ from src.server.outbound_message import NodeType
from src.server.server import ChiaServer
from src.util.network import parse_host_port
logging.basicConfig(format='Introducer %(name)-24s: %(levelname)-8s %(asctime)s.%(msecs)03d %(message)s',
level=logging.INFO,
datefmt='%H:%M:%S'
)
logging.basicConfig(
format="Introducer %(name)-24s: %(levelname)-8s %(asctime)s.%(msecs)03d %(message)s",
level=logging.INFO,
datefmt="%H:%M:%S",
)
async def main():
@@ -22,9 +23,11 @@ async def main():
def signal_received():
server.close_all()
asyncio.get_running_loop().add_signal_handler(signal.SIGINT, signal_received)
asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, signal_received)
await server.await_closed()
asyncio.run(main())

View File

@@ -8,10 +8,11 @@ from src.timelord import Timelord
from src.types.peer_info import PeerInfo
from src.util.network import parse_host_port
logging.basicConfig(format='Timelord %(name)-25s: %(levelname)-8s %(asctime)s.%(msecs)03d %(message)s',
level=logging.INFO,
datefmt='%H:%M:%S'
)
logging.basicConfig(
format="Timelord %(name)-25s: %(levelname)-8s %(asctime)s.%(msecs)03d %(message)s",
level=logging.INFO,
datefmt="%H:%M:%S",
)
async def main():
@@ -23,11 +24,14 @@ async def main():
def signal_received():
server.close_all()
asyncio.create_task(timelord._shutdown())
asyncio.get_running_loop().add_signal_handler(signal.SIGINT, signal_received)
asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, signal_received)
full_node_peer = PeerInfo(timelord.config['full_node_peer']['host'],
timelord.config['full_node_peer']['port'])
full_node_peer = PeerInfo(
timelord.config["full_node_peer"]["host"],
timelord.config["full_node_peer"]["port"],
)
await server.start_client(full_node_peer, None)
@@ -36,4 +40,5 @@ async def main():
await server.await_closed()
asyncio.run(main())

View File

@@ -14,8 +14,7 @@ from lib.chiavdf.inkfish.create_discriminant import create_discriminant
from lib.chiavdf.inkfish.proof_of_time import check_proof_of_time_nwesolowski
from src.consensus.constants import constants
from src.protocols import timelord_protocol
from src.server.outbound_message import (Delivery, Message, NodeType,
OutboundMessage)
from src.server.outbound_message import Delivery, Message, NodeType, OutboundMessage
from src.types.classgroup import ClassgroupElement
from src.types.proof_of_time import ProofOfTime
from src.types.sized_bytes import bytes32
@@ -29,8 +28,9 @@ class Timelord:
def __init__(self):
config_filename = os.path.join(ROOT_DIR, "src", "config", "config.yaml")
self.config = safe_load(open(config_filename, "r"))["timelord"]
self.free_servers: List[Tuple[str, str]] = list(zip(self.config["vdf_server_ips"],
self.config["vdf_server_ports"]))
self.free_servers: List[Tuple[str, str]] = list(
zip(self.config["vdf_server_ips"], self.config["vdf_server_ports"])
)
self.lock: Lock = Lock()
self.server_count: int = len(self.free_servers)
self.active_discriminants: Dict[bytes32, Tuple[StreamWriter, uint64, str]] = {}
@@ -47,8 +47,11 @@ class Timelord:
async def _shutdown(self):
async with self.lock:
for stop_discriminant, (stop_writer, _, _) in self.active_discriminants.items():
stop_writer.write(b'10')
for (
stop_discriminant,
(stop_writer, _, _),
) in self.active_discriminants.items():
stop_writer.write(b"10")
await stop_writer.drain()
self.done_discriminants.append(stop_discriminant)
self.active_discriminants.clear()
@@ -61,27 +64,51 @@ class Timelord:
stop_writer: Optional[StreamWriter] = None
stop_discriminant: Optional[bytes32] = None
low_weights = {k: v for k, v in self.active_discriminants.items() if v[1] == worst_weight_active}
no_iters = {k: v for k, v in low_weights.items()
if k not in self.pending_iters or len(self.pending_iters[k]) == 0}
low_weights = {
k: v
for k, v in self.active_discriminants.items()
if v[1] == worst_weight_active
}
no_iters = {
k: v
for k, v in low_weights.items()
if k not in self.pending_iters or len(self.pending_iters[k]) == 0
}
# If we have process(es) with no iters, stop the one that started the latest
if len(no_iters) > 0:
latest_start_time = max([self.active_discriminants_start_time[k] for k, _ in no_iters.items()])
stop_discriminant, stop_writer = next((k, v[0]) for k, v in no_iters.items()
if self.active_discriminants_start_time[k] == latest_start_time)
latest_start_time = max(
[self.active_discriminants_start_time[k] for k, _ in no_iters.items()]
)
stop_discriminant, stop_writer = next(
(k, v[0])
for k, v in no_iters.items()
if self.active_discriminants_start_time[k] == latest_start_time
)
else:
# Otherwise, pick the one that finishes one proof the latest.
best_iter = {k: min(self.pending_iters[k]) for k, _ in low_weights.items()}
time_taken = {k: time.time() - self.active_discriminants_start_time[k] for k, _ in low_weights.items()}
expected_finish = {k: max(0, (best_iter[k] - time_taken[k] * self.avg_ips[v[2]][0]) / self.avg_ips[v[2]][0])
for k, v in low_weights.items()}
time_taken = {
k: time.time() - self.active_discriminants_start_time[k]
for k, _ in low_weights.items()
}
expected_finish = {
k: max(
0,
(best_iter[k] - time_taken[k] * self.avg_ips[v[2]][0])
/ self.avg_ips[v[2]][0],
)
for k, v in low_weights.items()
}
worst_finish = max([v for v in expected_finish.values()])
log.info(f"Worst finish time: {worst_finish}s")
stop_discriminant, stop_writer = next((k, v[0]) for k, v in low_weights.items()
if expected_finish[k] == worst_finish)
stop_discriminant, stop_writer = next(
(k, v[0])
for k, v in low_weights.items()
if expected_finish[k] == worst_finish
)
assert stop_writer is not None
stop_writer.write(b'10')
stop_writer.write(b"10")
await stop_writer.drain()
del self.active_discriminants[stop_discriminant]
del self.active_discriminants_start_time[stop_discriminant]
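
The expected_finish expression above estimates, per discriminant, the seconds left until its smallest pending iteration count completes: remaining iterations divided by that server's average ips, floored at zero. A worked example with invented numbers:

# Worked example of the expected-finish estimate; all numbers are invented.
best_iter = 1_000_000  # smallest pending iteration count for this discriminant
time_taken = 40.0  # seconds the VDF has already been running
avg_ips = 20_000.0  # measured iterations per second for this server

expected_finish = max(0, (best_iter - time_taken * avg_ips) / avg_ips)
assert expected_finish == 10.0  # (1_000_000 - 800_000) / 20_000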
@@ -90,10 +117,14 @@ class Timelord:
async def _update_avg_ips(self, challenge_hash, iterations_needed, ip):
async with self.lock:
if challenge_hash in self.active_discriminants:
time_taken = time.time() - self.active_discriminants_start_time[challenge_hash]
ips = int(iterations_needed / time_taken * 10)/10
log.info(f"Finished PoT, chall:{challenge_hash[:10].hex()}.."
f" {iterations_needed} iters. {int(time_taken*1000)/1000}s, {ips} ips")
time_taken = (
time.time() - self.active_discriminants_start_time[challenge_hash]
)
ips = int(iterations_needed / time_taken * 10) / 10
log.info(
f"Finished PoT, chall:{challenge_hash[:10].hex()}.."
f" {iterations_needed} iters. {int(time_taken*1000)/1000}s, {ips} ips"
)
if ip not in self.avg_ips:
self.avg_ips[ip] = (ips, 1)
else:
@ -103,34 +134,46 @@ class Timelord:
log.info(f"New estimate: {new_avg_ips}")
self.pending_iters[challenge_hash].remove(iterations_needed)
else:
log.info(f"Finished PoT chall:{challenge_hash[:10].hex()}.. {iterations_needed}"
f" iters. But challenge not active anymore")
log.info(
f"Finished PoT chall:{challenge_hash[:10].hex()}.. {iterations_needed}"
f" iters. But challenge not active anymore"
)
async def _update_proofs_count(self, challenge_weight):
async with self.lock:
if (challenge_weight not in self.proof_count):
if challenge_weight not in self.proof_count:
self.proof_count[challenge_weight] = 1
else:
self.proof_count[challenge_weight] += 1
if (self.proof_count[challenge_weight] >= 3):
if self.proof_count[challenge_weight] >= 3:
log.info("Cleaning up servers")
self.best_weight_three_proofs = max(self.best_weight_three_proofs, challenge_weight)
self.best_weight_three_proofs = max(
self.best_weight_three_proofs, challenge_weight
)
for active_disc in list(self.active_discriminants):
current_writer, current_weight, _ = self.active_discriminants[active_disc]
if (current_weight <= challenge_weight):
current_writer, current_weight, _ = self.active_discriminants[
active_disc
]
if current_weight <= challenge_weight:
log.info(f"Active weight cleanup: {current_weight}")
log.info(f"Cleanup weight: {challenge_weight}")
current_writer.write(b'10')
current_writer.write(b"10")
await current_writer.drain()
del self.active_discriminants[active_disc]
del self.active_discriminants_start_time[active_disc]
self.done_discriminants.append(active_disc)
async def _do_process_communication(self, challenge_hash, challenge_weight, ip, port):
disc: int = create_discriminant(challenge_hash, constants["DISCRIMINANT_SIZE_BITS"])
async def _do_process_communication(
self, challenge_hash, challenge_weight, ip, port
):
disc: int = create_discriminant(
challenge_hash, constants["DISCRIMINANT_SIZE_BITS"]
)
log.info("Attempting SSH connection")
proc = await asyncio.create_subprocess_shell(f"./lib/chiavdf/fast_vdf/vdf_server {port}")
proc = await asyncio.create_subprocess_shell(
f"./lib/chiavdf/fast_vdf/vdf_server {port}"
)
# TODO(Florin): Handle connection failure (attempt another server)
writer: Optional[StreamWriter] = None
@ -150,7 +193,7 @@ class Timelord:
await writer.drain()
ok = await reader.readexactly(2)
assert(ok.decode() == "OK")
assert ok.decode() == "OK"
log.info("Got handshake with VDF server.")
@ -159,7 +202,7 @@ class Timelord:
self.active_discriminants_start_time[challenge_hash] = time.time()
async with self.lock:
if (challenge_hash in self.pending_iters):
if challenge_hash in self.pending_iters:
log.info(f"Writing pending iters {challenge_hash}")
for iter in sorted(self.pending_iters[challenge_hash]):
writer.write((str(len(str(iter))) + str(iter)).encode())
@ -173,7 +216,7 @@ class Timelord:
log.warn(f"{type(e)} {e}")
break
if (data.decode() == "STOP"):
if data.decode() == "STOP":
log.info("Stopped server")
writer.write(b"ACK")
await writer.drain()
@ -184,14 +227,17 @@ class Timelord:
len_server = len(self.free_servers)
log.info(f"Process ended... Server length {len_server}")
break
elif (data.decode() == "POLL"):
elif data.decode() == "POLL":
async with self.lock:
# If a heavier discriminant is waiting in the queue, free up this VDF server
if (len(self.discriminant_queue) > 0 and
challenge_weight < max([h for _, h in self.discriminant_queue])
and challenge_hash in self.active_discriminants):
if (
len(self.discriminant_queue) > 0
and challenge_weight
< max([h for _, h in self.discriminant_queue])
and challenge_hash in self.active_discriminants
):
log.info("Got poll, stopping the challenge!")
writer.write(b'10')
writer.write(b"10")
await writer.drain()
del self.active_discriminants[challenge_hash]
del self.active_discriminants_start_time[challenge_hash]
@ -200,58 +246,96 @@ class Timelord:
try:
# This must be a proof, read the continuation.
proof = await reader.readexactly(1860)
stdout_bytes_io: io.BytesIO = io.BytesIO(bytes.fromhex(data.decode() + proof.decode()))
stdout_bytes_io: io.BytesIO = io.BytesIO(
bytes.fromhex(data.decode() + proof.decode())
)
except Exception as e:
e_to_str = str(e)
log.error(f"Socket error: {e_to_str}")
iterations_needed = uint64(int.from_bytes(stdout_bytes_io.read(8), "big", signed=True))
iterations_needed = uint64(
int.from_bytes(stdout_bytes_io.read(8), "big", signed=True)
)
y = ClassgroupElement.parse(stdout_bytes_io)
proof_bytes: bytes = stdout_bytes_io.read()
# Verifies our own proof just in case
proof_blob = ClassGroup.from_ab_discriminant(y.a, y.b, disc).serialize() + proof_bytes
proof_blob = (
ClassGroup.from_ab_discriminant(y.a, y.b, disc).serialize()
+ proof_bytes
)
x = ClassGroup.from_ab_discriminant(2, 1, disc)
if (not check_proof_of_time_nwesolowski(disc, x, proof_blob, iterations_needed,
constants["DISCRIMINANT_SIZE_BITS"],
self.config["n_wesolowski"])):
if not check_proof_of_time_nwesolowski(
disc,
x,
proof_blob,
iterations_needed,
constants["DISCRIMINANT_SIZE_BITS"],
self.config["n_wesolowski"],
):
log.error("My proof is incorrect!")
output = ClassgroupElement(y.a, y.b)
proof_of_time = ProofOfTime(challenge_hash, iterations_needed, output,
self.config['n_wesolowski'], [uint8(b) for b in proof_bytes])
proof_of_time = ProofOfTime(
challenge_hash,
iterations_needed,
output,
self.config["n_wesolowski"],
[uint8(b) for b in proof_bytes],
)
response = timelord_protocol.ProofOfTimeFinished(proof_of_time)
await self._update_avg_ips(challenge_hash, iterations_needed, ip)
async with self.lock:
self.proofs_to_write.append(OutboundMessage(NodeType.FULL_NODE,
Message("proof_of_time_finished", response),
Delivery.BROADCAST))
self.proofs_to_write.append(
OutboundMessage(
NodeType.FULL_NODE,
Message("proof_of_time_finished", response),
Delivery.BROADCAST,
)
)
await self._update_proofs_count(challenge_weight)
async def _manage_discriminant_queue(self):
while not self.is_shutdown:
async with self.lock:
if (len(self.discriminant_queue) > 0):
if len(self.discriminant_queue) > 0:
max_weight = max([h for _, h in self.discriminant_queue])
if (max_weight <= self.best_weight_three_proofs):
self.done_discriminants.extend([d for d, _ in self.discriminant_queue])
if max_weight <= self.best_weight_three_proofs:
self.done_discriminants.extend(
[d for d, _ in self.discriminant_queue]
)
self.discriminant_queue.clear()
else:
disc = next(d for d, h in self.discriminant_queue if h == max_weight)
if (len(self.free_servers) != 0):
disc = next(
d for d, h in self.discriminant_queue if h == max_weight
)
if len(self.free_servers) != 0:
ip, port = self.free_servers[0]
self.free_servers = self.free_servers[1:]
self.discriminant_queue.remove((disc, max_weight))
asyncio.create_task(self._do_process_communication(disc, max_weight, ip, port))
asyncio.create_task(
self._do_process_communication(
disc, max_weight, ip, port
)
)
else:
if (len(self.active_discriminants) == self.server_count):
worst_weight_active = min([h for (_, h, _) in self.active_discriminants.values()])
if (max_weight > worst_weight_active):
if len(self.active_discriminants) == self.server_count:
worst_weight_active = min(
[
h
for (
_,
h,
_,
) in self.active_discriminants.values()
]
)
if max_weight > worst_weight_active:
await self._stop_worst_process(worst_weight_active)
if (len(self.proofs_to_write) > 0):
if len(self.proofs_to_write) > 0:
for msg in self.proofs_to_write:
yield msg
self.proofs_to_write.clear()
@ -264,31 +348,47 @@ class Timelord:
should be started on it. We add the challenge to the queue only if it is worth processing.
"""
async with self.lock:
if (challenge_start.challenge_hash in self.seen_discriminants):
log.info(f"Already seen this challenge hash {challenge_start.challenge_hash}. Ignoring.")
if challenge_start.challenge_hash in self.seen_discriminants:
log.info(
f"Already seen this challenge hash {challenge_start.challenge_hash}. Ignoring."
)
return
if (challenge_start.weight <= self.best_weight_three_proofs):
if challenge_start.weight <= self.best_weight_three_proofs:
log.info("Not starting challenge, already three proofs at that weight")
return
self.seen_discriminants.append(challenge_start.challenge_hash)
self.discriminant_queue.append((challenge_start.challenge_hash, challenge_start.weight))
self.discriminant_queue.append(
(challenge_start.challenge_hash, challenge_start.weight)
)
log.info("Appended to discriminant queue.")
@api_request
async def proof_of_space_info(self, proof_of_space_info: timelord_protocol.ProofOfSpaceInfo):
async def proof_of_space_info(
self, proof_of_space_info: timelord_protocol.ProofOfSpaceInfo
):
"""
Notification from full node about a new proof of space for a challenge. If we already
have a process for this challenge, we notify that process of how many
iterations to run for.
"""
async with self.lock:
if (proof_of_space_info.challenge_hash in self.active_discriminants):
writer, _, _ = self.active_discriminants[proof_of_space_info.challenge_hash]
writer.write(((str(len(str(proof_of_space_info.iterations_needed))) +
str(proof_of_space_info.iterations_needed)).encode()))
if proof_of_space_info.challenge_hash in self.active_discriminants:
writer, _, _ = self.active_discriminants[
proof_of_space_info.challenge_hash
]
writer.write(
(
(
str(len(str(proof_of_space_info.iterations_needed)))
+ str(proof_of_space_info.iterations_needed)
).encode()
)
)
await writer.drain()
elif (proof_of_space_info.challenge_hash in self.done_discriminants):
elif proof_of_space_info.challenge_hash in self.done_discriminants:
return
if (proof_of_space_info.challenge_hash not in self.pending_iters):
if proof_of_space_info.challenge_hash not in self.pending_iters:
self.pending_iters[proof_of_space_info.challenge_hash] = []
self.pending_iters[proof_of_space_info.challenge_hash].append(proof_of_space_info.iterations_needed)
self.pending_iters[proof_of_space_info.challenge_hash].append(
proof_of_space_info.iterations_needed
)
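A standalone sketch of the timelord-to-vdf_server framing used throughout this file; the helper name and STOP constant are illustrative, since the repo writes the expressions inline. Iteration targets are sent as the decimal length of the digit string followed by the digits themselves, and the two bytes b"10" ask the server to stop:

    def encode_iterations(iterations: int) -> bytes:
        # Mirrors (str(len(str(iter))) + str(iter)).encode() from the handlers above.
        digits = str(iterations)
        return (str(len(digits)) + digits).encode()

    STOP = b"10"  # written to the server socket, then drained

    assert encode_iterations(123456) == b"6123456"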
View File
@ -23,7 +23,7 @@ class FullBlock(Streamable):
@property
def weight(self) -> uint64:
if (self.header_block.challenge):
if self.header_block.challenge:
return self.header_block.challenge.total_weight
else:
return uint64(0)
View File
@ -25,8 +25,9 @@ class ProofOfSpace(Streamable):
def verify_and_get_quality(self) -> Optional[bytes32]:
v: Verifier = Verifier()
plot_seed: bytes32 = self.get_plot_seed()
quality_str = v.validate_proof(plot_seed, self.size, self.challenge_hash,
bytes(self.proof))
quality_str = v.validate_proof(
plot_seed, self.size, self.challenge_hash, bytes(self.proof)
)
if not quality_str:
return None
return self.quality_str_to_quality(self.challenge_hash, quality_str)
View File
@ -20,12 +20,14 @@ class ProofOfTime(Streamable):
witness: List[uint8]
def is_valid(self, discriminant_size_bits):
disc: int = create_discriminant(self.challenge_hash,
discriminant_size_bits)
disc: int = create_discriminant(self.challenge_hash, discriminant_size_bits)
x = ClassGroup.from_ab_discriminant(2, 1, disc)
y = ClassGroup.from_ab_discriminant(self.output.a,
self.output.b, disc)
return check_proof_of_time_nwesolowski(disc, x, y.serialize() + bytes(self.witness),
self.number_of_iterations,
discriminant_size_bits,
self.witness_type)
y = ClassGroup.from_ab_discriminant(self.output.a, self.output.b, disc)
return check_proof_of_time_nwesolowski(
disc,
x,
y.serialize() + bytes(self.witness),
self.number_of_iterations,
discriminant_size_bits,
self.witness_type,
)
View File
@ -9,8 +9,7 @@ from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.layout.containers import HSplit, VSplit, Window
from prompt_toolkit.layout.layout import Layout
from prompt_toolkit.styles import Style
from prompt_toolkit.widgets import (Button, Frame, Label, SearchToolbar,
TextArea)
from prompt_toolkit.widgets import Button, Frame, Label, SearchToolbar, TextArea
from src.blockchain import Blockchain
from src.db.database import FullNodeStore
from src.server.connection import NodeType, PeerConnections
@ -24,8 +23,15 @@ from src.util.ints import uint16
log = logging.getLogger(__name__)
def start_ssh_server(store: FullNodeStore, blockchain: Blockchain, server: ChiaServer,
port: int, ssh_port: int, ssh_key_filename: str, close_cb: Callable):
def start_ssh_server(
store: FullNodeStore,
blockchain: Blockchain,
server: ChiaServer,
port: int,
ssh_port: int,
ssh_key_filename: str,
close_cb: Callable,
):
"""
Starts an SSH server that creates FullNodeUI instances whenever someone connects to the port.
Returns a coroutine that can be awaited, which completes when all UI instances have been closed.
@ -57,12 +63,14 @@ def start_ssh_server(store: FullNodeStore, blockchain: Blockchain, server: ChiaS
uis.append(ui)
await ui.app.run_async()
asyncio.get_running_loop().create_task(asyncssh.create_server(
lambda: PromptToolkitSSHServer(interact),
"",
ssh_port,
server_host_keys=[ssh_key_filename],
))
asyncio.get_running_loop().create_task(
asyncssh.create_server(
lambda: PromptToolkitSSHServer(interact),
"",
ssh_port,
server_host_keys=[ssh_key_filename],
)
)
return await_all_closed, ui_close_cb
@ -72,8 +80,15 @@ class FullNodeUI:
when the full node is closed. Uses store, blockchain, and connections, to display relevant
information. The UI is updated periodically.
"""
def __init__(self, store: FullNodeStore, blockchain: Blockchain, server: ChiaServer,
port: int, parent_close_cb: Callable):
def __init__(
self,
store: FullNodeStore,
blockchain: Blockchain,
server: ChiaServer,
port: int,
parent_close_cb: Callable,
):
self.port: int = port
self.store: FullNodeStore = store
self.blockchain: Blockchain = blockchain
@ -89,11 +104,14 @@ class FullNodeUI:
self.parent_close_cb = parent_close_cb
self.kb = self.setup_keybindings()
self.draw_initial()
self.style = Style([
('error', '#ff0044'),
])
self.app = Application(style=self.style, layout=self.layout, full_screen=True,
key_bindings=self.kb, mouse_support=True)
self.style = Style([("error", "#ff0044"),])
self.app = Application(
style=self.style,
layout=self.layout,
full_screen=True,
key_bindings=self.kb,
mouse_support=True,
)
self.closed = False
self.update_task = asyncio.get_running_loop().create_task(self.update())
@ -123,6 +141,7 @@ class FullNodeUI:
@kb.add("c-c")
def exit_(event):
self.close()
return kb
def draw_initial(self):
@ -130,8 +149,8 @@ class FullNodeUI:
self.empty_row = TextArea(focusable=False, height=1)
# home/
self.loading_msg = Label(text=f'Initializing UI....')
self.server_msg = Label(text=f'Server running on port {self.port}.')
self.loading_msg = Label(text=f"Initializing UI....")
self.server_msg = Label(text=f"Server running on port {self.port}.")
self.syncing = TextArea(focusable=False, height=1)
self.current_heads_label = TextArea(focusable=False, height=1)
self.lca_label = TextArea(focusable=False, height=1)
@ -141,45 +160,64 @@ class FullNodeUI:
self.con_rows = []
self.connections_msg = Label(text=f"Connections")
self.connection_rows_vsplit = Window()
self.add_connection_msg = Label(text=f'Add a connection ip:port')
self.add_connection_field = TextArea(height=1, prompt='>>> ', style='class:input-field',
multiline=False, wrap_lines=False, search_field=search_field)
self.add_connection_field.accept_handler = self.async_to_sync(self.add_connection)
self.latest_blocks_msg = Label(text=f'Latest blocks')
self.latest_blocks_labels = [Button(text="block") for _ in range(self.num_blocks)]
self.add_connection_msg = Label(text=f"Add a connection ip:port")
self.add_connection_field = TextArea(
height=1,
prompt=">>> ",
style="class:input-field",
multiline=False,
wrap_lines=False,
search_field=search_field,
)
self.add_connection_field.accept_handler = self.async_to_sync(
self.add_connection
)
self.latest_blocks_msg = Label(text=f"Latest blocks")
self.latest_blocks_labels = [
Button(text="block") for _ in range(self.num_blocks)
]
self.search_block_msg = Label(text=f'Search block by hash')
self.search_block_field = TextArea(height=1, prompt='>>> ', style='class:input-field',
multiline=False, wrap_lines=False, search_field=search_field)
self.search_block_msg = Label(text=f"Search block by hash")
self.search_block_field = TextArea(
height=1,
prompt=">>> ",
style="class:input-field",
multiline=False,
wrap_lines=False,
search_field=search_field,
)
self.search_block_field.accept_handler = self.async_to_sync(self.search_block)
self.close_ui_button = Button('Close UI', handler=self.close)
self.quit_button = Button('Stop node and close UI', handler=self.stop)
self.error_msg = Label(style='class:error', text=f'')
self.close_ui_button = Button("Close UI", handler=self.close)
self.quit_button = Button("Stop node and close UI", handler=self.stop)
self.error_msg = Label(style="class:error", text=f"")
# block/
self.block_msg = Label(text=f'Block')
self.block_msg = Label(text=f"Block")
self.block_label = TextArea(focusable=True, scrollbar=True, focus_on_click=True)
self.back_button = Button(text="Back", handler=self.change_route_handler("home/"))
self.challenge_msg = Label(text=f'Block Header')
self.back_button = Button(
text="Back", handler=self.change_route_handler("home/")
)
self.challenge_msg = Label(text=f"Block Header")
self.challenge = TextArea(focusable=False)
body = HSplit([self.loading_msg, self.server_msg], height=D(), width=D())
self.content = Frame(title="Chia Full Node", body=body)
self.layout = Layout(VSplit([self.content], height=D(), width=D()))
def change_route_handler(self, route):
def change_route():
self.prev_route = self.route
self.route = route
self.focused = False
self.error_msg.text = ""
return change_route
def async_to_sync(self, coroutine):
def inner(buff):
asyncio.get_running_loop().create_task(coroutine(buff.text))
return inner
async def search_block(self, text: str):
@ -198,7 +236,9 @@ class FullNodeUI:
try:
ip, port = text.split(":")
except ValueError: # Not yet in layout
self.error_msg.text = "Enter a valid IP and port in the following format: 10.5.4.3:8000"
self.error_msg.text = (
"Enter a valid IP and port in the following format: 10.5.4.3:8000"
)
return
target_node: PeerInfo = PeerInfo(ip, uint16(int(port)))
log.error(f"Want to connect to {ip}, {port}")
@ -229,6 +269,7 @@ class FullNodeUI:
labels = [row.children[0].content.text() for row in self.con_rows]
if con_str not in labels:
con_label = Label(text=con_str)
def disconnect():
con.close()
self.layout.focus(self.quit_button)
@ -305,20 +346,35 @@ class FullNodeUI:
self.focused = True
except ValueError: # Not yet in layout
pass
return HSplit([self.server_msg, self.syncing, self.lca_label, self.current_heads_label,
self.difficulty_label, self.ips_label, self.total_iters_label,
Window(height=1, char='-', style='class:line'),
self.connections_msg,
new_con_rows,
Window(height=1, char='-', style='class:line'),
self.add_connection_msg,
self.add_connection_field,
Window(height=1, char='-', style='class:line'),
self.latest_blocks_msg, *new_labels,
Window(height=1, char='-', style='class:line'),
self.search_block_msg, self.search_block_field,
Window(height=1, char='-', style='class:line'),
self.close_ui_button, self.quit_button, self.error_msg], width=D(), height=D())
return HSplit(
[
self.server_msg,
self.syncing,
self.lca_label,
self.current_heads_label,
self.difficulty_label,
self.ips_label,
self.total_iters_label,
Window(height=1, char="-", style="class:line"),
self.connections_msg,
new_con_rows,
Window(height=1, char="-", style="class:line"),
self.add_connection_msg,
self.add_connection_field,
Window(height=1, char="-", style="class:line"),
self.latest_blocks_msg,
*new_labels,
Window(height=1, char="-", style="class:line"),
self.search_block_msg,
self.search_block_field,
Window(height=1, char="-", style="class:line"),
self.close_ui_button,
self.quit_button,
self.error_msg,
],
width=D(),
height=D(),
)
async def draw_block(self):
block_hash: str = self.route.split("block/")[1]
@ -338,7 +394,9 @@ class FullNodeUI:
self.focused = True
except ValueError: # Not yet in layout
pass
return HSplit([self.block_msg, self.block_label, self.back_button], width=D(), height=D())
return HSplit(
[self.block_msg, self.block_label, self.back_button], width=D(), height=D()
)
async def update(self):
try:
View File
@ -12,6 +12,7 @@ def api_request(f):
def new_challenge(challenge):
# handle request
"""
@functools.wraps(f)
def f_substitute(*args, **kwargs):
sig = signature(f)
@ -27,4 +28,5 @@ def api_request(f):
inter[param_name] = param_class(**inter[param_name])
return f(**inter)
return f_substitute
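For context on what the wrapper above buys its callers: any argument passed as a plain dict is rebuilt into the parameter's annotated class before the handler runs. A minimal, self-contained re-creation of the idea (the sketch name and the Ping message type are hypothetical, not the repo's):

    import functools
    from dataclasses import dataclass
    from inspect import signature

    def api_request_sketch(f):
        @functools.wraps(f)
        def f_substitute(*args, **kwargs):
            sig = signature(f)
            bound = sig.bind(*args, **kwargs)
            converted = {}
            for name, value in bound.arguments.items():
                # Rebuild dict arguments into the annotated message class.
                if isinstance(value, dict):
                    value = sig.parameters[name].annotation(**value)
                converted[name] = value
            return f(**converted)
        return f_substitute

    @dataclass
    class Ping:  # hypothetical message type
        nonce: int

    @api_request_sketch
    def ping(ping: Ping):
        return ping.nonce

    assert ping({"nonce": 7}) == 7  # the dict arrives as a constructed Ping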
View File
@ -40,7 +40,14 @@ def make_sized_bytes(size):
def __repr__(self):
return "<%s: %s>" % (self.__class__.__name__, str(self))
namespace = dict(__new__=__new__, parse=parse, stream=stream, from_bytes=from_bytes,
__bytes__=__bytes__, __str__=__str__, __repr__=__repr__)
namespace = dict(
__new__=__new__,
parse=parse,
stream=stream,
from_bytes=from_bytes,
__bytes__=__bytes__,
__str__=__str__,
__repr__=__repr__,
)
return type(name, (bytes,), namespace)
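A hedged usage sketch of the factory above, assuming make_sized_bytes lives in src.types.sized_bytes next to the bytes32 the rest of the repo imports:

    from src.types.sized_bytes import make_sized_bytes

    bytes4 = make_sized_bytes(4)  # hypothetical 4-byte variant for illustration
    h = bytes4(b"\x00\x01\x02\x03")
    print(repr(h))  # a "<bytes4: ...>" style repr, per __repr__ above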
View File
@ -9,4 +9,4 @@ def cbor_message(cls: Any) -> Type:
they are the right type.
"""
cls1 = strictdataclass(cls=cls)
return type(cls.__name__, (cls1,), {'__cbor_message__': True})
return type(cls.__name__, (cls1,), {"__cbor_message__": True})
View File
@ -6,8 +6,8 @@ from src.types.sized_bytes import bytes32
def parse_host_port(api) -> Tuple[str, int]:
host: str = sys.argv[1] if len(sys.argv) >= 3 else api.config['host']
port: int = int(sys.argv[2]) if len(sys.argv) >= 3 else api.config['port']
host: str = sys.argv[1] if len(sys.argv) >= 3 else api.config["host"]
port: int = int(sys.argv[2]) if len(sys.argv) >= 3 else api.config["port"]
return (host, port)
View File
@ -3,9 +3,11 @@ def partial_async_gen(f, *args):
Returns an async generator function which is equivalent to the passed-in function,
but only takes in one parameter (the first one).
"""
async def inner(first_param):
async for x in f(first_param, *args):
yield x
return inner
@ -14,6 +16,8 @@ def partial_async(f, *args):
Returns an async function which is equivalent to the passed-in function,
but only takes in one parameter (the first one).
"""
async def inner(first_param):
return await f(first_param, *args)
return inner
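A self-contained usage sketch of partial_async; the definition is restated so the example runs standalone, and lookup and its table are illustrative:

    import asyncio

    def partial_async(f, *args):
        async def inner(first_param):
            return await f(first_param, *args)
        return inner

    async def lookup(key: str, table: dict) -> int:
        return table[key]

    # Pin every argument after the first, leaving a one-parameter coroutine.
    get = partial_async(lookup, {"a": 1, "b": 2})
    print(asyncio.run(get("b")))  # 2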
View File
@ -7,14 +7,24 @@ import pprint
from hashlib import sha256
from typing import Any, BinaryIO, List, Type, get_type_hints
from blspy import (ChainCode, ExtendedPrivateKey, ExtendedPublicKey,
InsecureSignature, PrependSignature, PrivateKey, PublicKey,
Signature)
from blspy import (
ChainCode,
ExtendedPrivateKey,
ExtendedPublicKey,
InsecureSignature,
PrependSignature,
PrivateKey,
PublicKey,
Signature,
)
from src.types.sized_bytes import bytes32
from src.util.ints import uint32
from src.util.type_checking import (is_type_List, is_type_SpecificOptional,
strictdataclass)
from src.util.type_checking import (
is_type_List,
is_type_SpecificOptional,
strictdataclass,
)
pp = pprint.PrettyPrinter(indent=1, width=120, compact=True)
@ -27,10 +37,18 @@ size_hints = {
"PrependSignature": PrependSignature.SIGNATURE_SIZE,
"ExtendedPublicKey": ExtendedPublicKey.EXTENDED_PUBLIC_KEY_SIZE,
"ExtendedPrivateKey": ExtendedPrivateKey.EXTENDED_PRIVATE_KEY_SIZE,
"ChainCode": ChainCode.CHAIN_CODE_KEY_SIZE
"ChainCode": ChainCode.CHAIN_CODE_KEY_SIZE,
}
unhashable_types = [PrivateKey, PublicKey, Signature, PrependSignature, InsecureSignature,
ExtendedPublicKey, ExtendedPrivateKey, ChainCode]
unhashable_types = [
PrivateKey,
PublicKey,
Signature,
PrependSignature,
InsecureSignature,
ExtendedPublicKey,
ExtendedPrivateKey,
ChainCode,
]
def streamable(cls: Any):
@ -88,7 +106,7 @@ class Streamable:
return f_type.from_bytes(f.read(size_hints[f_type.__name__]))
if f_type is str:
str_size: uint32 = uint32(int.from_bytes(f.read(4), "big"))
return bytes.decode(f.read(str_size), 'utf-8')
return bytes.decode(f.read(str_size), "utf-8")
else:
raise RuntimeError(f"Type {f_type} does not have parse")
@ -121,7 +139,7 @@ class Streamable:
f.write(bytes(item))
elif f_type is str:
f.write(uint32(len(item)).to_bytes(4, "big"))
f.write(item.encode('utf-8'))
f.write(item.encode("utf-8"))
else:
raise NotImplementedError(f"can't stream {item}, {f_type}")
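The string wire format above, restated as a round-trip sketch: a 4-byte big-endian length followed by UTF-8 bytes. One caveat worth flagging: the stream branch writes len(item), a character count, while parse consumes that many bytes, so the two agree only for ASCII strings. The sketch below uses the byte length, which always round-trips:

    import io

    def write_str(f, s: str) -> None:
        data = s.encode("utf-8")
        f.write(len(data).to_bytes(4, "big"))  # byte length, not character count
        f.write(data)

    def read_str(f) -> str:
        size = int.from_bytes(f.read(4), "big")
        return f.read(size).decode("utf-8")

    buf = io.BytesIO()
    write_str(buf, "chia")
    buf.seek(0)
    assert read_str(buf) == "chia"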
View File
@ -1,4 +1,3 @@
import io
import struct
from typing import Any, BinaryIO
@ -10,12 +9,15 @@ class StructStream(int):
"""
Create a class that can parse and stream itself based on a struct.pack template string.
"""
def __new__(cls: Any, value: int):
bits = struct.calcsize(cls.PACK) * 8
value = int(value)
if value.bit_length() > bits:
raise ValueError(f"Value {value} of size {value.bit_length()} does not fit into "
f"{cls.__name__} of size {bits}")
raise ValueError(
f"Value {value} of size {value.bit_length()} does not fit into "
f"{cls.__name__} of size {bits}"
)
return int.__new__(cls, value) # type: ignore
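To make the contract concrete, a condensed runnable sketch of how a subclass opts in; the PACK template shown for the 8-bit case is an assumption about src.util.ints, not confirmed by this diff:

    import struct
    from typing import Any

    class StructStreamSketch(int):
        PACK: str  # subclasses set a struct.pack template

        def __new__(cls: Any, value: int):
            bits = struct.calcsize(cls.PACK) * 8
            value = int(value)
            if value.bit_length() > bits:
                raise ValueError(
                    f"Value {value} of size {value.bit_length()} does not fit into "
                    f"{cls.__name__} of size {bits}"
                )
            return int.__new__(cls, value)

    class uint8_sketch(StructStreamSketch):
        PACK = ">B"  # assumed template: one unsigned byte -> 8 bits

    uint8_sketch(255)    # fits in 8 bits
    # uint8_sketch(256)  # would raise ValueError: does not fit into 8 bits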
View File
@ -3,25 +3,31 @@ from typing import Any, List, Type, Union, get_type_hints
def is_type_List(f_type: Type) -> bool:
return (hasattr(f_type, "__origin__") and f_type.__origin__ == list) or f_type == list
return (
hasattr(f_type, "__origin__") and f_type.__origin__ == list
) or f_type == list
def is_type_SpecificOptional(f_type) -> bool:
"""
Returns true for types such as Optional[T], but not for bare Optional or T.
"""
return (hasattr(f_type, "__origin__") and f_type.__origin__ == Union
and f_type.__args__[1]() is None)
return (
hasattr(f_type, "__origin__")
and f_type.__origin__ == Union
and f_type.__args__[1]() is None
)
def strictdataclass(cls: Any):
class _Local():
class _Local:
"""
Dataclass where all fields must be type annotated, and type checking is performed
at initialization, even recursively through Lists. Non-annotated fields are ignored.
Also, for any fields which have a type with .from_bytes(bytes) or constructor(bytes),
bytes can be passed in and the type can be constructed.
"""
def parse_item(self, item: Any, f_name: str, f_type: Type) -> Any:
if is_type_List(f_type):
collected_list: List = []
@ -53,7 +59,9 @@ def strictdataclass(cls: Any):
for (f_name, f_type) in fields.items():
if f_name not in data:
raise ValueError(f"Field {f_name} not present")
object.__setattr__(self, f_name, self.parse_item(data[f_name], f_name, f_type))
object.__setattr__(
self, f_name, self.parse_item(data[f_name], f_name, f_type)
)
class NoTypeChecking:
__no_type_check__ = True
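Quick sanity checks matching the docstrings above; these mirror the documented behavior rather than the repo's own tests:

    from typing import List, Optional

    from src.util.type_checking import is_type_List, is_type_SpecificOptional

    assert is_type_List(List[int]) and is_type_List(list)
    assert is_type_SpecificOptional(Optional[int])  # Optional[T] qualifies
    assert not is_type_SpecificOptional(int)        # a bare T does not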
View File
@ -32,12 +32,14 @@ k: uint8 = uint8(19)
# Uses many plots for testing, in order to guarantee proofs of space at every height
num_plots = 80
# Use the empty string as the seed for the private key
pool_sk: PrivateKey = PrivateKey.from_seed(b'')
pool_sk: PrivateKey = PrivateKey.from_seed(b"")
pool_pk: PublicKey = pool_sk.get_public_key()
plot_sks: List[PrivateKey] = [PrivateKey.from_seed(pn.to_bytes(4, "big")) for pn in range(num_plots)]
plot_sks: List[PrivateKey] = [
PrivateKey.from_seed(pn.to_bytes(4, "big")) for pn in range(num_plots)
]
plot_pks: List[PublicKey] = [sk.get_public_key() for sk in plot_sks]
farmer_sk: PrivateKey = PrivateKey.from_seed(b'coinbase')
farmer_sk: PrivateKey = PrivateKey.from_seed(b"coinbase")
coinbase_target = sha256(bytes(farmer_sk.get_public_key())).digest()
fee_target = sha256(bytes(farmer_sk.get_public_key())).digest()
n_wesolowski = uint8(3)
@ -49,10 +51,20 @@ class BlockTools:
"""
def __init__(self):
plot_seeds: List[bytes32] = [ProofOfSpace.calculate_plot_seed(pool_pk, plot_pk) for plot_pk in plot_pks]
self.filenames: List[str] = [os.path.join("tests", "plots", "genesis-plots-" + str(k) +
sha256(int.to_bytes(i, 4, "big")).digest().hex() + ".dat")
for i in range(num_plots)]
plot_seeds: List[bytes32] = [
ProofOfSpace.calculate_plot_seed(pool_pk, plot_pk) for plot_pk in plot_pks
]
self.filenames: List[str] = [
os.path.join(
"tests",
"plots",
"genesis-plots-"
+ str(k)
+ sha256(int.to_bytes(i, 4, "big")).digest().hex()
+ ".dat",
)
for i in range(num_plots)
]
done_filenames = set()
try:
for pn, filename in enumerate(self.filenames):
@ -66,12 +78,14 @@ class BlockTools:
os.remove(filename)
sys.exit(1)
def get_consecutive_blocks(self,
input_constants: Dict,
num_blocks: int,
block_list: List[FullBlock] = [],
seconds_per_block=constants["BLOCK_TIME_TARGET"],
seed: bytes = b'') -> List[FullBlock]:
def get_consecutive_blocks(
self,
input_constants: Dict,
num_blocks: int,
block_list: List[FullBlock] = [],
seconds_per_block=constants["BLOCK_TIME_TARGET"],
seed: bytes = b"",
) -> List[FullBlock]:
test_constants: Dict[str, Any] = constants.copy()
for key, value in input_constants.items():
test_constants[key] = value
@ -80,11 +94,17 @@ class BlockTools:
if "GENESIS_BLOCK" in test_constants:
block_list.append(FullBlock.from_bytes(test_constants["GENESIS_BLOCK"]))
else:
block_list.append(self.create_genesis_block(test_constants, sha256(seed).digest(), seed))
block_list.append(
self.create_genesis_block(
test_constants, sha256(seed).digest(), seed
)
)
prev_difficulty = test_constants["DIFFICULTY_STARTING"]
curr_difficulty = prev_difficulty
curr_ips = test_constants["VDF_IPS_STARTING"]
elif len(block_list) < (test_constants["DIFFICULTY_EPOCH"] + test_constants["DIFFICULTY_DELAY"]):
elif len(block_list) < (
test_constants["DIFFICULTY_EPOCH"] + test_constants["DIFFICULTY_DELAY"]
):
# First epoch (+delay), so just get first difficulty
prev_difficulty = block_list[0].weight
curr_difficulty = block_list[0].weight
@ -92,23 +112,35 @@ class BlockTools:
curr_ips = test_constants["VDF_IPS_STARTING"]
else:
curr_difficulty = block_list[-1].weight - block_list[-2].weight
prev_difficulty = (block_list[-1 - test_constants["DIFFICULTY_EPOCH"]].weight -
block_list[-2 - test_constants["DIFFICULTY_EPOCH"]].weight)
prev_difficulty = (
block_list[-1 - test_constants["DIFFICULTY_EPOCH"]].weight
- block_list[-2 - test_constants["DIFFICULTY_EPOCH"]].weight
)
assert block_list[-1].header_block.proof_of_time
curr_ips = calculate_ips_from_iterations(block_list[-1].header_block.proof_of_space,
curr_difficulty,
block_list[-1].header_block.proof_of_time
.number_of_iterations,
test_constants["MIN_BLOCK_TIME"])
curr_ips = calculate_ips_from_iterations(
block_list[-1].header_block.proof_of_space,
curr_difficulty,
block_list[-1].header_block.proof_of_time.number_of_iterations,
test_constants["MIN_BLOCK_TIME"],
)
starting_height = block_list[-1].height + 1
timestamp = block_list[-1].header_block.header.data.timestamp
for next_height in range(starting_height, starting_height + num_blocks):
if (next_height > test_constants["DIFFICULTY_EPOCH"] and
next_height % test_constants["DIFFICULTY_EPOCH"] == test_constants["DIFFICULTY_DELAY"]):
if (
next_height > test_constants["DIFFICULTY_EPOCH"]
and next_height % test_constants["DIFFICULTY_EPOCH"]
== test_constants["DIFFICULTY_DELAY"]
):
# Calculates new difficulty
height1 = uint64(next_height - (test_constants["DIFFICULTY_EPOCH"] +
test_constants["DIFFICULTY_DELAY"]) - 1)
height1 = uint64(
next_height
- (
test_constants["DIFFICULTY_EPOCH"]
+ test_constants["DIFFICULTY_DELAY"]
)
- 1
)
height2 = uint64(next_height - (test_constants["DIFFICULTY_EPOCH"]) - 1)
height3 = uint64(next_height - (test_constants["DIFFICULTY_DELAY"]) - 1)
if height1 >= 0:
@ -119,8 +151,10 @@ class BlockTools:
else:
block1 = block_list[0]
assert block1.header_block.challenge
timestamp1 = (block1.header_block.header.data.timestamp -
test_constants["BLOCK_TIME_TARGET"])
timestamp1 = (
block1.header_block.header.data.timestamp
- test_constants["BLOCK_TIME_TARGET"]
)
iters1 = block1.header_block.challenge.total_iters
timestamp2 = block_list[height2].header_block.header.data.timestamp
timestamp3 = block_list[height3].header_block.header.data.timestamp
@ -128,42 +162,83 @@ class BlockTools:
block3 = block_list[height3]
assert block3.header_block.challenge
iters3 = block3.header_block.challenge.total_iters
term1 = (test_constants["DIFFICULTY_DELAY"] * prev_difficulty *
(timestamp3 - timestamp2) * test_constants["BLOCK_TIME_TARGET"])
term1 = (
test_constants["DIFFICULTY_DELAY"]
* prev_difficulty
* (timestamp3 - timestamp2)
* test_constants["BLOCK_TIME_TARGET"]
)
term2 = ((test_constants["DIFFICULTY_WARP_FACTOR"] - 1) *
(test_constants["DIFFICULTY_EPOCH"] - test_constants["DIFFICULTY_DELAY"]) * curr_difficulty
* (timestamp2 - timestamp1) * test_constants["BLOCK_TIME_TARGET"])
term2 = (
(test_constants["DIFFICULTY_WARP_FACTOR"] - 1)
* (
test_constants["DIFFICULTY_EPOCH"]
- test_constants["DIFFICULTY_DELAY"]
)
* curr_difficulty
* (timestamp2 - timestamp1)
* test_constants["BLOCK_TIME_TARGET"]
)
# Round down after the division
new_difficulty: uint64 = uint64((term1 + term2) //
(test_constants["DIFFICULTY_WARP_FACTOR"] *
(timestamp3 - timestamp2) *
(timestamp2 - timestamp1)))
new_difficulty: uint64 = uint64(
(term1 + term2)
// (
test_constants["DIFFICULTY_WARP_FACTOR"]
* (timestamp3 - timestamp2)
* (timestamp2 - timestamp1)
)
)
if new_difficulty >= curr_difficulty:
new_difficulty = min(new_difficulty, uint64(test_constants["DIFFICULTY_FACTOR"] *
curr_difficulty))
new_difficulty = min(
new_difficulty,
uint64(test_constants["DIFFICULTY_FACTOR"] * curr_difficulty),
)
else:
new_difficulty = max([uint64(1), new_difficulty,
uint64(curr_difficulty // test_constants["DIFFICULTY_FACTOR"])])
new_difficulty = max(
[
uint64(1),
new_difficulty,
uint64(
curr_difficulty // test_constants["DIFFICULTY_FACTOR"]
),
]
)
new_ips = uint64((iters3 - iters1)//(timestamp3 - timestamp1))
new_ips = uint64((iters3 - iters1) // (timestamp3 - timestamp1))
if new_ips >= curr_ips:
curr_ips = min(new_ips, uint64(test_constants["IPS_FACTOR"] * new_ips))
curr_ips = min(
new_ips, uint64(test_constants["IPS_FACTOR"] * new_ips)
)
else:
curr_ips = max([uint64(1), new_ips, uint64(curr_ips // test_constants["IPS_FACTOR"])])
curr_ips = max(
[
uint64(1),
new_ips,
uint64(curr_ips // test_constants["IPS_FACTOR"]),
]
)
prev_difficulty = curr_difficulty
curr_difficulty = new_difficulty
time_taken = seconds_per_block
timestamp += time_taken
block_list.append(self.create_next_block(test_constants, block_list[-1], timestamp, curr_difficulty,
curr_ips, seed))
block_list.append(
self.create_next_block(
test_constants,
block_list[-1],
timestamp,
curr_difficulty,
curr_ips,
seed,
)
)
return block_list
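The retarget arithmetic re-wrapped above collapses into one expression; this is a restatement of the code, with delta = DIFFICULTY_DELAY, E = DIFFICULTY_EPOCH, w = DIFFICULTY_WARP_FACTOR, T = BLOCK_TIME_TARGET, and t1 < t2 < t3 the three reference timestamps:

    new_difficulty = (
        delta * prev_difficulty * (t3 - t2) * T
        + (w - 1) * (E - delta) * curr_difficulty * (t2 - t1) * T
    ) // (w * (t3 - t2) * (t2 - t1))

The result is then clamped to within a factor of DIFFICULTY_FACTOR of the current difficulty, and the IPS estimate is recomputed as (iters3 - iters1) // (timestamp3 - timestamp1), clamped analogously by IPS_FACTOR.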
def create_genesis_block(self, input_constants: Dict, challenge_hash=bytes([0]*32),
seed: bytes = b'') -> FullBlock:
def create_genesis_block(
self, input_constants: Dict, challenge_hash=bytes([0] * 32), seed: bytes = b""
) -> FullBlock:
"""
Creates the genesis block with the specified details.
"""
@ -175,18 +250,24 @@ class BlockTools:
test_constants,
challenge_hash,
uint32(0),
bytes([0]*32),
bytes([0] * 32),
uint64(0),
uint64(0),
uint64(int(time.time())),
uint64(test_constants["DIFFICULTY_STARTING"]),
uint64(test_constants["VDF_IPS_STARTING"]),
seed
seed,
)
def create_next_block(self, input_constants: Dict, prev_block: FullBlock, timestamp: uint64,
difficulty: uint64, ips: uint64,
seed: bytes = b'') -> FullBlock:
def create_next_block(
self,
input_constants: Dict,
prev_block: FullBlock,
timestamp: uint64,
difficulty: uint64,
ips: uint64,
seed: bytes = b"",
) -> FullBlock:
"""
Creates the next block with the specified details.
"""
@ -206,12 +287,22 @@ class BlockTools:
timestamp,
uint64(difficulty),
ips,
seed
seed,
)
def _create_block(self, test_constants: Dict, challenge_hash: bytes32, height: uint32, prev_header_hash: bytes32,
prev_iters: uint64, prev_weight: uint64, timestamp: uint64, difficulty: uint64,
ips: uint64, seed: bytes) -> FullBlock:
def _create_block(
self,
test_constants: Dict,
challenge_hash: bytes32,
height: uint32,
prev_header_hash: bytes32,
prev_iters: uint64,
prev_weight: uint64,
timestamp: uint64,
difficulty: uint64,
ips: uint64,
seed: bytes,
) -> FullBlock:
"""
Creates a block with the specified details. Uses the stored plots to create a proof of space,
and also evaluates the VDF for the proof of time.
@ -238,38 +329,65 @@ class BlockTools:
raise NoProofsOfSpaceFound("No proofs for this challenge")
proof_xs: bytes = prover.get_full_proof(challenge_hash, 0)
proof_of_space: ProofOfSpace = ProofOfSpace(challenge_hash, pool_pk, plot_pk, k, [uint8(b) for b in proof_xs])
number_iters: uint64 = pot_iterations.calculate_iterations(proof_of_space,
difficulty, ips,
test_constants["MIN_BLOCK_TIME"])
proof_of_space: ProofOfSpace = ProofOfSpace(
challenge_hash, pool_pk, plot_pk, k, [uint8(b) for b in proof_xs]
)
number_iters: uint64 = pot_iterations.calculate_iterations(
proof_of_space, difficulty, ips, test_constants["MIN_BLOCK_TIME"]
)
disc: int = create_discriminant(challenge_hash, test_constants["DISCRIMINANT_SIZE_BITS"])
disc: int = create_discriminant(
challenge_hash, test_constants["DISCRIMINANT_SIZE_BITS"]
)
start_x: ClassGroup = ClassGroup.from_ab_discriminant(2, 1, disc)
y_cl, proof_bytes = create_proof_of_time_nwesolowski(
disc, start_x, number_iters, disc, n_wesolowski)
disc, start_x, number_iters, disc, n_wesolowski
)
output = ClassgroupElement(y_cl[0], y_cl[1])
proof_of_time = ProofOfTime(challenge_hash, number_iters, output, n_wesolowski, [uint8(b) for b in proof_bytes])
proof_of_time = ProofOfTime(
challenge_hash,
number_iters,
output,
n_wesolowski,
[uint8(b) for b in proof_bytes],
)
coinbase: CoinbaseInfo = CoinbaseInfo(height, block_rewards.calculate_block_reward(uint32(height)),
coinbase_target)
coinbase: CoinbaseInfo = CoinbaseInfo(
height,
block_rewards.calculate_block_reward(uint32(height)),
coinbase_target,
)
coinbase_sig: PrependSignature = pool_sk.sign_prepend(bytes(coinbase))
fees_target: FeesTarget = FeesTarget(fee_target, uint64(0))
solutions_generator: bytes32 = sha256(seed).digest()
cost = uint64(0)
body: Body = Body(coinbase, coinbase_sig, fees_target, None, solutions_generator, cost)
body: Body = Body(
coinbase, coinbase_sig, fees_target, None, solutions_generator, cost
)
header_data: HeaderData = HeaderData(prev_header_hash, timestamp, bytes([0]*32),
proof_of_space.get_hash(), body.get_hash(),
bytes([0]*32))
header_data: HeaderData = HeaderData(
prev_header_hash,
timestamp,
bytes([0] * 32),
proof_of_space.get_hash(),
body.get_hash(),
bytes([0] * 32),
)
header_hash_sig: PrependSignature = plot_sk.sign_prepend(header_data.get_hash())
header: Header = Header(header_data, header_hash_sig)
challenge = Challenge(challenge_hash, proof_of_space.get_hash(), proof_of_time.get_hash(), height,
uint64(prev_weight + difficulty), uint64(prev_iters + number_iters))
challenge = Challenge(
challenge_hash,
proof_of_space.get_hash(),
proof_of_time.get_hash(),
height,
uint64(prev_weight + difficulty),
uint64(prev_iters + number_iters),
)
header_block = HeaderBlock(proof_of_space, proof_of_time, challenge, header)
full_block: FullBlock = FullBlock(header_block, body)
View File
@ -2,18 +2,23 @@ from decimal import Decimal
from hashlib import sha256
from math import log
from src.consensus.pot_iterations import (_expected_plot_size,
_quality_to_decimal,
calculate_iterations_quality)
from src.consensus.pot_iterations import (
_expected_plot_size,
_quality_to_decimal,
calculate_iterations_quality,
)
from src.util.ints import uint8, uint64
class TestPotIterations():
class TestPotIterations:
def test_pade_approximation(self):
def test_approximation(input_dec, threshold):
bytes_input = int(Decimal(input_dec) * pow(2, 256)).to_bytes(32, "big")
print(_quality_to_decimal(bytes_input))
assert abs(1 - Decimal(-log(input_dec)) / _quality_to_decimal(bytes_input)) < threshold
assert (
abs(1 - Decimal(-log(input_dec)) / _quality_to_decimal(bytes_input))
< threshold
)
# The approximations become better the closer to 1 the input gets
test_approximation(0.7, 0.01)
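Stated plainly, each call above asserts a relative-error bound on the Padé-style approximation: for input x in (0, 1) and a threshold eps, it requires |1 - (-ln x) / q(x)| < eps, where q(x) is _quality_to_decimal applied to the 32-byte big-endian encoding of floor(x * 2**256); the bound tightens as x approaches 1.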
@ -27,8 +32,20 @@ class TestPotIterations():
Tests that the percentage of blocks won is proportional to the space of each farmer,
with the assumption that all farmers have access to the same VDF speed.
"""
farmer_ks = [uint8(34), uint8(35), uint8(36), uint8(37), uint8(38), uint8(39), uint8(39),
uint8(39), uint8(39), uint8(39), uint8(40), uint8(41)]
farmer_ks = [
uint8(34),
uint8(35),
uint8(36),
uint8(37),
uint8(38),
uint8(39),
uint8(39),
uint8(39),
uint8(39),
uint8(39),
uint8(40),
uint8(41),
]
farmer_space = [_expected_plot_size(uint8(k)) for k in farmer_ks]
total_space = sum(farmer_space)
percentage_space = [float(sp / total_space) for sp in farmer_space]
@ -36,11 +53,20 @@ class TestPotIterations():
total_blocks = 5000
for b_index in range(total_blocks):
qualities = [sha256(b_index.to_bytes(32, "big") + bytes(farmer_index)).digest()
for farmer_index in range(len(farmer_ks))]
iters = [calculate_iterations_quality(qualities[i], farmer_ks[i], uint64(50000000),
uint64(5000), uint64(10))
for i in range(len(qualities))]
qualities = [
sha256(b_index.to_bytes(32, "big") + bytes(farmer_index)).digest()
for farmer_index in range(len(farmer_ks))
]
iters = [
calculate_iterations_quality(
qualities[i],
farmer_ks[i],
uint64(50000000),
uint64(5000),
uint64(10),
)
for i in range(len(qualities))
]
wins[iters.index(min(iters))] += 1
win_percentage = [wins[w] / total_blocks for w in range(len(farmer_ks))]
View File
@ -53,8 +53,14 @@ class TestGenesisBlock:
genesis_block = bc1.get_current_tips()[0]
assert genesis_block.height == 0
assert genesis_block.challenge
assert (await bc1.get_header_blocks_by_height([uint64(0)], genesis_block.header_hash))[0] == genesis_block
assert (await bc1.get_next_difficulty(genesis_block.header_hash)) == genesis_block.challenge.total_weight
assert (
await bc1.get_header_blocks_by_height(
[uint64(0)], genesis_block.header_hash
)
)[0] == genesis_block
assert (
await bc1.get_next_difficulty(genesis_block.header_hash)
) == genesis_block.challenge.total_weight
assert await bc1.get_next_ips(genesis_block.header_hash) > 0
@ -78,74 +84,99 @@ class TestBlockValidation:
@pytest.mark.asyncio
async def test_prev_pointer(self, initial_blockchain):
blocks, b = initial_blockchain
block_bad = FullBlock(HeaderBlock(
block_bad = FullBlock(
HeaderBlock(
blocks[9].header_block.proof_of_space,
blocks[9].header_block.proof_of_time,
blocks[9].header_block.challenge,
Header(HeaderData(
bytes([1]*32),
Header(
HeaderData(
bytes([1] * 32),
blocks[9].header_block.header.data.timestamp,
blocks[9].header_block.header.data.filter_hash,
blocks[9].header_block.header.data.proof_of_space_hash,
blocks[9].header_block.header.data.body_hash,
blocks[9].header_block.header.data.extension_data
), blocks[9].header_block.header.harvester_signature)
), blocks[9].body)
assert (await b.receive_block(block_bad)) == ReceiveBlockResult.DISCONNECTED_BLOCK
blocks[9].header_block.header.data.extension_data,
),
blocks[9].header_block.header.harvester_signature,
),
),
blocks[9].body,
)
assert (
await b.receive_block(block_bad)
) == ReceiveBlockResult.DISCONNECTED_BLOCK
@pytest.mark.asyncio
async def test_timestamp(self, initial_blockchain):
blocks, b = initial_blockchain
# Time too far in the past
block_bad = FullBlock(HeaderBlock(
block_bad = FullBlock(
HeaderBlock(
blocks[9].header_block.proof_of_space,
blocks[9].header_block.proof_of_time,
blocks[9].header_block.challenge,
Header(HeaderData(
Header(
HeaderData(
blocks[9].header_block.header.data.prev_header_hash,
blocks[9].header_block.header.data.timestamp - 1000,
blocks[9].header_block.header.data.filter_hash,
blocks[9].header_block.header.data.proof_of_space_hash,
blocks[9].header_block.header.data.body_hash,
blocks[9].header_block.header.data.extension_data
), blocks[9].header_block.header.harvester_signature)
), blocks[9].body)
blocks[9].header_block.header.data.extension_data,
),
blocks[9].header_block.header.harvester_signature,
),
),
blocks[9].body,
)
assert (await b.receive_block(block_bad)) == ReceiveBlockResult.INVALID_BLOCK
# Time too far in the future
block_bad = FullBlock(HeaderBlock(
block_bad = FullBlock(
HeaderBlock(
blocks[9].header_block.proof_of_space,
blocks[9].header_block.proof_of_time,
blocks[9].header_block.challenge,
Header(HeaderData(
Header(
HeaderData(
blocks[9].header_block.header.data.prev_header_hash,
uint64(int(time.time() + 3600 * 3)),
blocks[9].header_block.header.data.filter_hash,
blocks[9].header_block.header.data.proof_of_space_hash,
blocks[9].header_block.header.data.body_hash,
blocks[9].header_block.header.data.extension_data
), blocks[9].header_block.header.harvester_signature)
), blocks[9].body)
blocks[9].header_block.header.data.extension_data,
),
blocks[9].header_block.header.harvester_signature,
),
),
blocks[9].body,
)
assert (await b.receive_block(block_bad)) == ReceiveBlockResult.INVALID_BLOCK
@pytest.mark.asyncio
async def test_body_hash(self, initial_blockchain):
blocks, b = initial_blockchain
block_bad = FullBlock(HeaderBlock(
block_bad = FullBlock(
HeaderBlock(
blocks[9].header_block.proof_of_space,
blocks[9].header_block.proof_of_time,
blocks[9].header_block.challenge,
Header(HeaderData(
Header(
HeaderData(
blocks[9].header_block.header.data.prev_header_hash,
blocks[9].header_block.header.data.timestamp,
blocks[9].header_block.header.data.filter_hash,
blocks[9].header_block.header.data.proof_of_space_hash,
bytes([1]*32),
blocks[9].header_block.header.data.extension_data
), blocks[9].header_block.header.harvester_signature)
), blocks[9].body)
bytes([1] * 32),
blocks[9].header_block.header.data.extension_data,
),
blocks[9].header_block.header.harvester_signature,
),
),
blocks[9].body,
)
assert (await b.receive_block(block_bad)) == ReceiveBlockResult.INVALID_BLOCK
@ -153,14 +184,18 @@ class TestBlockValidation:
async def test_harvester_signature(self, initial_blockchain):
blocks, b = initial_blockchain
# Time too far in the past
block_bad = FullBlock(HeaderBlock(
block_bad = FullBlock(
HeaderBlock(
blocks[9].header_block.proof_of_space,
blocks[9].header_block.proof_of_time,
blocks[9].header_block.challenge,
Header(
blocks[9].header_block.header.data,
PrivateKey.from_seed(b'0').sign_prepend(b"random junk"))
), blocks[9].body)
blocks[9].header_block.header.data,
PrivateKey.from_seed(b"0").sign_prepend(b"random junk"),
),
),
blocks[9].body,
)
assert (await b.receive_block(block_bad)) == ReceiveBlockResult.INVALID_BLOCK
@pytest.mark.asyncio
@ -170,18 +205,21 @@ class TestBlockValidation:
bad_pos = blocks[9].header_block.proof_of_space.proof
bad_pos[0] = (bad_pos[0] + 1) % 256
# Proof of space invalid
block_bad = FullBlock(HeaderBlock(
block_bad = FullBlock(
HeaderBlock(
ProofOfSpace(
blocks[9].header_block.proof_of_space.challenge_hash,
blocks[9].header_block.proof_of_space.pool_pubkey,
blocks[9].header_block.proof_of_space.plot_pubkey,
blocks[9].header_block.proof_of_space.size,
bad_pos
bad_pos,
),
blocks[9].header_block.proof_of_time,
blocks[9].header_block.challenge,
blocks[9].header_block.header
), blocks[9].body)
blocks[9].header_block.header,
),
blocks[9].body,
)
assert (await b.receive_block(block_bad)) == ReceiveBlockResult.INVALID_BLOCK
@pytest.mark.asyncio
@ -189,7 +227,9 @@ class TestBlockValidation:
blocks, b = initial_blockchain
# Coinbase height invalid
block_bad = FullBlock(blocks[9].header_block, Body(
block_bad = FullBlock(
blocks[9].header_block,
Body(
CoinbaseInfo(
uint32(3),
blocks[9].body.coinbase.amount,
@ -199,8 +239,9 @@ class TestBlockValidation:
blocks[9].body.fees_target_info,
blocks[9].body.aggregated_signature,
blocks[9].body.solutions_generator,
blocks[9].body.cost
))
blocks[9].body.cost,
),
)
assert (await b.receive_block(block_bad)) == ReceiveBlockResult.INVALID_BLOCK
@pytest.mark.asyncio
@ -226,14 +267,24 @@ class TestBlockValidation:
assert diff_27 > diff_26
assert (diff_27 / diff_26) <= test_constants["DIFFICULTY_FACTOR"]
assert (await b.get_next_ips(blocks[1].header_hash)) == constants["VDF_IPS_STARTING"]
assert (await b.get_next_ips(blocks[24].header_hash)) == (await b.get_next_ips(blocks[23].header_hash))
assert (await b.get_next_ips(blocks[25].header_hash)) == (await b.get_next_ips(blocks[24].header_hash))
assert (await b.get_next_ips(blocks[26].header_hash)) > (await b.get_next_ips(blocks[25].header_hash))
assert (await b.get_next_ips(blocks[27].header_hash)) == (await b.get_next_ips(blocks[26].header_hash))
assert (await b.get_next_ips(blocks[1].header_hash)) == constants[
"VDF_IPS_STARTING"
]
assert (await b.get_next_ips(blocks[24].header_hash)) == (
await b.get_next_ips(blocks[23].header_hash)
)
assert (await b.get_next_ips(blocks[25].header_hash)) == (
await b.get_next_ips(blocks[24].header_hash)
)
assert (await b.get_next_ips(blocks[26].header_hash)) > (
await b.get_next_ips(blocks[25].header_hash)
)
assert (await b.get_next_ips(blocks[27].header_hash)) == (
await b.get_next_ips(blocks[26].header_hash)
)
class TestReorgs():
class TestReorgs:
@pytest.mark.asyncio
async def test_basic_reorg(self):
blocks = bt.get_consecutive_blocks(test_constants, 100, [], 9)
View File
@ -51,5 +51,5 @@ class TestStreamable(unittest.TestCase):
pass
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
View File
@ -3,8 +3,11 @@ from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from src.util.ints import uint8
from src.util.type_checking import (is_type_List, is_type_SpecificOptional,
strictdataclass)
from src.util.type_checking import (
is_type_List,
is_type_SpecificOptional,
strictdataclass,
)
class TestIsTypeList(unittest.TestCase):
@ -100,5 +103,5 @@ class TestStrictClass(unittest.TestCase):
A()
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()