Compare commits

...

3 Commits

Author SHA1 Message Date
Daira Emma Hopwood 19c30cf11a Finish the implementation of Streamlet and add tests.
Signed-off-by: Daira Emma Hopwood <daira@jacaranda.org>
2023-12-13 21:46:47 +00:00
Daira Emma Hopwood 6bdd41f1c8 Re-add __init__.py files.
Signed-off-by: Daira Emma Hopwood <daira@jacaranda.org>
2023-12-13 21:46:47 +00:00
Daira Emma Hopwood 1af789f04e __init__.py files should not contain significant code.
Signed-off-by: Daira Emma Hopwood <daira@jacaranda.org>
2023-12-13 21:46:23 +00:00
10 changed files with 913 additions and 658 deletions

View File

@ -14,301 +14,3 @@ collects the fees paid by the other transactions in the block.
The simulation of the shielded protocol does not attempt to model any
actual privacy properties.
"""
from __future__ import annotations
from typing import Iterable, Optional
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum, auto
from collections import deque
from itertools import chain, islice
from sys import version_info
from ..util import Unique
class BlockHash(Unique):
"""Unique value representing a best-chain block hash."""
pass
class BCTransaction:
"""A transaction for a best-chain protocol."""
@dataclass(frozen=True)
class _TXO:
tx: BCTransaction
index: int
value: int
@dataclass(eq=False)
class _Note(Unique):
"""
A shielded note. Unlike in the actual protocol, we conflate notes, note
commitments, and nullifiers. This will be sufficient because we don't
need to maintain any actual privacy.
This is not a frozen dataclass; its identity is important, and models the
fact that each note has a unique commitment and nullifier in the actual
protocol.
"""
value: int
def __init__(self,
transparent_inputs: Sequence[BCTransaction._TXO],
transparent_output_values: Sequence[int],
shielded_inputs: Sequence[BCTransaction._Note],
shielded_output_values: Sequence[int],
fee: int,
anchor: Optional[BCContext]=None,
issuance: int=0):
"""
Constructs a `BCTransaction` with the given transparent inputs, transparent
output values, anchor, shielded inputs, shielded output values, fee, and
(if it is a coinbase transaction) issuance.
The elements of `transparent_inputs` are TXO objects obtained from the
`transparent_output` method of another `BCTransaction`. The elements of
`shielded_inputs` are Note objects obtained from the `shielded_output`
method of another `BCTransaction`. The TXO and Note classes are private,
and these objects should not be constructed directly.
The anchor is modelled as a `BCContext` such that
`anchor.can_spend(shielded_inputs)`. If there are no shielded inputs,
`anchor` must be `None`. The anchor object must not be modified after
passing it to this constructor (copy it if necessary).
For a coinbase transaction, pass `[]` for `transparent_inputs` and
`shielded_inputs`, and pass `fee` as a negative value of magnitude equal
to the total amount of fees paid by other transactions in the block.
"""
assert issuance >= 0
coinbase = len(transparent_inputs) + len(shielded_inputs) == 0
assert fee >= 0 or coinbase
assert issuance == 0 or coinbase
assert all((v >= 0 for v in chain(transparent_output_values, shielded_output_values)))
assert (
sum((txin.value for txin in transparent_inputs))
+ sum((note.value for note in shielded_inputs))
+ issuance ==
sum(transparent_output_values)
+ sum(shielded_output_values)
+ fee
)
assert anchor is None if len(shielded_inputs) == 0 else (
anchor is not None and anchor.can_spend(shielded_inputs))
self.transparent_inputs = transparent_inputs
self.transparent_outputs = [self._TXO(self, i, v)
for (i, v) in enumerate(transparent_output_values)]
self.shielded_inputs = shielded_inputs
self.shielded_outputs = [self._Note(v) for v in shielded_output_values]
self.fee = fee
self.anchor = anchor
self.issuance = issuance
def transparent_input(self, index: int) -> BCTransaction._TXO:
"""Returns the transparent input TXO with the given index."""
return self.transparent_inputs[index]
def transparent_output(self, index: int) -> BCTransaction._TXO:
"""Returns the transparent output TXO with the given index."""
return self.transparent_outputs[index]
def shielded_input(self, index: int) -> BCTransaction._Note:
"""Returns the shielded input note with the given index."""
return self.shielded_inputs[index]
def shielded_output(self, index: int) -> BCTransaction._Note:
"""Returns the shielded output note with the given index."""
return self.shielded_outputs[index]
def is_coinbase(self) -> bool:
"""
Returns `True` if this is a coinbase transaction (it has no inputs).
"""
return len(self.transparent_inputs) + len(self.shielded_inputs) == 0
class Spentness(Enum):
"""The spentness status of a note."""
Unspent = auto()
"""The note is unspent."""
Spent = auto()
"""The note is spent."""
class BCContext:
"""
A context that allows checking transactions for contextual validity in a
best-chain protocol.
"""
assert version_info >= (3, 7), "This code relies on insertion-ordered dicts."
def __init__(self):
"""Constructs an empty `BCContext`."""
self.transactions: deque[BCTransaction] = deque()
self.utxo_set: set[BCTransaction._TXO] = set()
# Since dicts are insertion-ordered, this models the sequence in which
# notes are committed as well as their spentness.
self.notes: dict[BCTransaction._Note, Spentness] = {}
self.total_issuance = 0
def committed_notes(self) -> list[(BCTransaction._Note, Spentness)]:
"""
Returns a list of (`Note`, `Spentness`) for notes added to this context,
preserving the commitment order.
"""
return list(self.notes.items())
def can_spend(self, tospend: Iterable[BCTransaction._Note]) -> bool:
"""Can all of the notes in `tospend` be spent in this context?"""
return all((self.notes.get(note) == Spentness.Unspent for note in tospend))
def _check(self, tx: BCTransaction) -> tuple[bool, set[BCTransaction._TXO]]:
"""
Checks whether `tx` is valid. To avoid recomputation, this returns
a pair of the validity, and the set of transparent inputs of `tx`.
"""
txins = set(tx.transparent_inputs)
valid = txins.issubset(self.utxo_set) and self.can_spend(tx.shielded_inputs)
return (valid, txins)
def is_valid(self, tx: BCTransaction) -> bool:
"""Is `tx` valid in this context?"""
return self._check(tx)[0]
def add_if_valid(self, tx: BCTransaction) -> bool:
"""
If `tx` is valid in this context, add it to the context and return `True`.
Otherwise leave the context unchanged and return `False`.
"""
(valid, txins) = self._check(tx)
if valid:
self.utxo_set -= txins
self.utxo_set |= set(tx.transparent_outputs)
for note in tx.shielded_inputs:
self.notes[note] = Spentness.Spent
for note in tx.shielded_outputs:
assert note not in self.notes
self.notes[note] = Spentness.Unspent
self.total_issuance += tx.issuance
self.transactions.append(tx)
return valid
def copy(self) -> BCContext:
"""Returns an independent copy of this `BCContext`."""
ctx = BCContext()
ctx.transactions = self.transactions.copy()
ctx.utxo_set = self.utxo_set.copy()
ctx.notes = self.notes.copy()
ctx.total_issuance = self.total_issuance
return ctx
class BCBlock:
"""A block in a best-chain protocol."""
def __init__(self,
parent: Optional[BCBlock],
added_score: int,
transactions: Sequence[BCTransaction],
allow_invalid: bool=False):
"""
Constructs a `BCBlock` with the given parent block, score relative to the
parent, and sequence of transactions. `transactions` must not be modified
after passing it to this constructor (copy it if necessary).
If `allow_invalid` is set, the block need not be valid.
Use `parent=None` to construct the genesis block.
"""
self.parent = parent
self.score = added_score
if self.parent is not None:
self.score += self.parent.score
self.transactions = transactions
self.hash = BlockHash()
if not allow_invalid:
self.assert_noncontextually_valid()
def assert_noncontextually_valid(self) -> None:
"""Assert that non-contextual consensus rules are satisfied for this block."""
assert len(self.transactions) > 0
assert self.transactions[0].is_coinbase()
assert not any((tx.is_coinbase() for tx in islice(self.transactions, 1, None)))
assert sum((tx.fee for tx in self.transactions)) == 0
def is_noncontextually_valid(self) -> bool:
"""Are non-contextual consensus rules satisfied for this block?"""
try:
self.assert_noncontextually_valid()
return True
except AssertionError:
return False
@dataclass
class BCProtocol:
"""A best-chain protocol."""
Transaction: type[object] = BCTransaction
"""The type of transactions for this protocol."""
Context: type[object] = BCContext
"""The type of contexts for this protocol."""
Block: type[object] = BCBlock
"""The type of blocks for this protocol."""
__all__ = ['BCTransaction', 'BCContext', 'BCBlock', 'BCProtocol', 'BlockHash', 'Spentness']
import unittest
class TestBC(unittest.TestCase):
def test_basic(self) -> None:
ctx = BCContext()
coinbase_tx0 = BCTransaction([], [10], [], [], 0, issuance=10)
self.assertTrue(ctx.add_if_valid(coinbase_tx0))
genesis = BCBlock(None, 1, [coinbase_tx0])
self.assertEqual(genesis.score, 1)
self.assertEqual(ctx.total_issuance, 10)
coinbase_tx1 = BCTransaction([], [6], [], [], -1, issuance=5)
spend_tx = BCTransaction([coinbase_tx0.transparent_output(0)], [9], [], [], 1)
self.assertTrue(ctx.add_if_valid(coinbase_tx1))
self.assertTrue(ctx.add_if_valid(spend_tx))
block1 = BCBlock(genesis, 1, [coinbase_tx1, spend_tx])
self.assertEqual(block1.score, 2)
self.assertEqual(ctx.total_issuance, 15)
coinbase_tx2 = BCTransaction([], [6], [], [], -1, issuance=5)
shielding_tx = BCTransaction([coinbase_tx1.transparent_output(0), spend_tx.transparent_output(0)],
[], [], [8, 6], 1)
self.assertTrue(ctx.add_if_valid(coinbase_tx2))
self.assertTrue(ctx.add_if_valid(shielding_tx))
block2 = BCBlock(block1, 2, [coinbase_tx2, shielding_tx])
block2_anchor = ctx.copy()
self.assertEqual(block2.score, 4)
self.assertEqual(ctx.total_issuance, 20)
coinbase_tx3 = BCTransaction([], [7], [], [], -2, issuance=5)
shielded_tx = BCTransaction([], [], [shielding_tx.shielded_output(0)], [7], 1,
anchor=block2_anchor)
deshielding_tx = BCTransaction([], [5], [shielding_tx.shielded_output(1)], [], 1,
anchor=block2_anchor)
self.assertTrue(ctx.add_if_valid(coinbase_tx3))
self.assertTrue(ctx.add_if_valid(shielded_tx))
self.assertTrue(ctx.add_if_valid(deshielding_tx))
block3 = BCBlock(block2, 3, [coinbase_tx3, shielded_tx, deshielding_tx])
self.assertEqual(block3.score, 7)
self.assertEqual(ctx.total_issuance, 25)

301
simtfl/bc/chain.py Normal file
View File

@ -0,0 +1,301 @@
"""
Abstractions for best-chain transactions, contexts, and blocks.
"""
from __future__ import annotations
from typing import Iterable, Optional, TypeAlias
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum, auto
from collections import deque
from itertools import chain, islice
from sys import version_info
from ..util import Unique
class BlockHash(Unique):
"""Unique value representing a best-chain block hash."""
pass
class BCTransaction:
"""A transaction for a best-chain protocol."""
@dataclass(frozen=True)
class _TXO:
tx: BCTransaction
index: int
value: int
@dataclass(eq=False)
class _Note(Unique):
"""
A shielded note. Unlike in the actual protocol, we conflate notes, note
commitments, and nullifiers. This will be sufficient because we don't
need to maintain any actual privacy.
This is not a frozen dataclass; its identity is important, and models the
fact that each note has a unique commitment and nullifier in the actual
protocol.
"""
value: int
def __init__(self,
transparent_inputs: Sequence[BCTransaction._TXO],
transparent_output_values: Sequence[int],
shielded_inputs: Sequence[BCTransaction._Note],
shielded_output_values: Sequence[int],
fee: int,
anchor: Optional[BCContext]=None,
issuance: int=0):
"""
Constructs a `BCTransaction` with the given transparent inputs, transparent
output values, anchor, shielded inputs, shielded output values, fee, and
(if it is a coinbase transaction) issuance.
The elements of `transparent_inputs` are TXO objects obtained from the
`transparent_output` method of another `BCTransaction`. The elements of
`shielded_inputs` are Note objects obtained from the `shielded_output`
method of another `BCTransaction`. The TXO and Note classes are private,
and these objects should not be constructed directly.
The anchor is modelled as a `BCContext` such that
`anchor.can_spend(shielded_inputs)`. If there are no shielded inputs,
`anchor` must be `None`. The anchor object must not be modified after
passing it to this constructor (copy it if necessary).
For a coinbase transaction, pass `[]` for `transparent_inputs` and
`shielded_inputs`, and pass `fee` as a negative value of magnitude equal
to the total amount of fees paid by other transactions in the block.
"""
assert issuance >= 0
coinbase = len(transparent_inputs) + len(shielded_inputs) == 0
assert fee >= 0 or coinbase
assert issuance == 0 or coinbase
assert all((v >= 0 for v in chain(transparent_output_values, shielded_output_values)))
assert (
sum((txin.value for txin in transparent_inputs))
+ sum((note.value for note in shielded_inputs))
+ issuance ==
sum(transparent_output_values)
+ sum(shielded_output_values)
+ fee
)
assert anchor is None if len(shielded_inputs) == 0 else (
anchor is not None and anchor.can_spend(shielded_inputs))
self.transparent_inputs = transparent_inputs
self.transparent_outputs = [self._TXO(self, i, v)
for (i, v) in enumerate(transparent_output_values)]
self.shielded_inputs = shielded_inputs
self.shielded_outputs = [self._Note(v) for v in shielded_output_values]
self.fee = fee
self.anchor = anchor
self.issuance = issuance
def transparent_input(self, index: int) -> BCTransaction._TXO:
"""Returns the transparent input TXO with the given index."""
return self.transparent_inputs[index]
def transparent_output(self, index: int) -> BCTransaction._TXO:
"""Returns the transparent output TXO with the given index."""
return self.transparent_outputs[index]
def shielded_input(self, index: int) -> BCTransaction._Note:
"""Returns the shielded input note with the given index."""
return self.shielded_inputs[index]
def shielded_output(self, index: int) -> BCTransaction._Note:
"""Returns the shielded output note with the given index."""
return self.shielded_outputs[index]
def is_coinbase(self) -> bool:
"""
Returns `True` if this is a coinbase transaction (it has no inputs).
"""
return len(self.transparent_inputs) + len(self.shielded_inputs) == 0
class Spentness(Enum):
"""The spentness status of a note."""
Unspent = auto()
"""The note is unspent."""
Spent = auto()
"""The note is spent."""
class BCContext:
"""
A context that allows checking transactions for contextual validity in a
best-chain protocol.
"""
assert version_info >= (3, 7), "This code relies on insertion-ordered dicts."
def __init__(self):
"""Constructs an empty `BCContext`."""
self.transactions: deque[BCTransaction] = deque()
self.utxo_set: set[BCTransaction._TXO] = set()
# Since dicts are insertion-ordered, this models the sequence in which
# notes are committed as well as their spentness.
self.notes: dict[BCTransaction._Note, Spentness] = {}
self.total_issuance = 0
def committed_notes(self) -> list[(BCTransaction._Note, Spentness)]:
"""
Returns a list of (`Note`, `Spentness`) for notes added to this context,
preserving the commitment order.
"""
return list(self.notes.items())
def can_spend(self, tospend: Iterable[BCTransaction._Note]) -> bool:
"""Can all of the notes in `tospend` be spent in this context?"""
return all((self.notes.get(note) == Spentness.Unspent for note in tospend))
def _check(self, tx: BCTransaction) -> tuple[bool, set[BCTransaction._TXO]]:
"""
Checks whether `tx` is valid. To avoid recomputation, this returns
a pair of the validity, and the set of transparent inputs of `tx`.
"""
txins = set(tx.transparent_inputs)
valid = txins.issubset(self.utxo_set) and self.can_spend(tx.shielded_inputs)
return (valid, txins)
def is_valid(self, tx: BCTransaction) -> bool:
"""Is `tx` valid in this context?"""
return self._check(tx)[0]
def add_if_valid(self, tx: BCTransaction) -> bool:
"""
If `tx` is valid in this context, add it to the context and return `True`.
Otherwise leave the context unchanged and return `False`.
"""
(valid, txins) = self._check(tx)
if valid:
self.utxo_set -= txins
self.utxo_set |= set(tx.transparent_outputs)
for note in tx.shielded_inputs:
self.notes[note] = Spentness.Spent
for note in tx.shielded_outputs:
assert note not in self.notes
self.notes[note] = Spentness.Unspent
self.total_issuance += tx.issuance
self.transactions.append(tx)
return valid
def copy(self) -> BCContext:
"""Returns an independent copy of this `BCContext`."""
ctx = BCContext()
ctx.transactions = self.transactions.copy()
ctx.utxo_set = self.utxo_set.copy()
ctx.notes = self.notes.copy()
ctx.total_issuance = self.total_issuance
return ctx
class BCBlock:
"""A block in a best-chain protocol."""
def __init__(self,
parent: Optional[BCBlock],
added_score: int,
transactions: Sequence[BCTransaction],
allow_invalid: bool=False):
"""
Constructs a `BCBlock` with the given parent block, score relative to the
parent, and sequence of transactions. `transactions` must not be modified
after passing it to this constructor (copy it if necessary).
If `allow_invalid` is set, the block need not be valid.
Use `parent=None` to construct the genesis block.
"""
self.parent = parent
self.score = added_score
if self.parent is not None:
self.score += self.parent.score
self.transactions = transactions
self.hash = BlockHash()
if not allow_invalid:
self.assert_noncontextually_valid()
def assert_noncontextually_valid(self) -> None:
"""Assert that non-contextual consensus rules are satisfied for this block."""
assert len(self.transactions) > 0
assert self.transactions[0].is_coinbase()
assert not any((tx.is_coinbase() for tx in islice(self.transactions, 1, None)))
assert sum((tx.fee for tx in self.transactions)) == 0
def is_noncontextually_valid(self) -> bool:
"""Are non-contextual consensus rules satisfied for this block?"""
try:
self.assert_noncontextually_valid()
return True
except AssertionError:
return False
@dataclass
class BCProtocol:
"""A best-chain protocol."""
Transaction: TypeAlias = BCTransaction
"""The type of transactions for this protocol."""
Context: TypeAlias = BCContext
"""The type of contexts for this protocol."""
Block: TypeAlias = BCBlock
"""The type of blocks for this protocol."""
__all__ = ['BCTransaction', 'BCContext', 'BCBlock', 'BCProtocol', 'BlockHash', 'Spentness']
import unittest
class TestBC(unittest.TestCase):
def test_basic(self) -> None:
ctx = BCContext()
coinbase_tx0 = BCTransaction([], [10], [], [], 0, issuance=10)
self.assertTrue(ctx.add_if_valid(coinbase_tx0))
genesis = BCBlock(None, 1, [coinbase_tx0])
self.assertEqual(genesis.score, 1)
self.assertEqual(ctx.total_issuance, 10)
coinbase_tx1 = BCTransaction([], [6], [], [], -1, issuance=5)
spend_tx = BCTransaction([coinbase_tx0.transparent_output(0)], [9], [], [], 1)
self.assertTrue(ctx.add_if_valid(coinbase_tx1))
self.assertTrue(ctx.add_if_valid(spend_tx))
block1 = BCBlock(genesis, 1, [coinbase_tx1, spend_tx])
self.assertEqual(block1.score, 2)
self.assertEqual(ctx.total_issuance, 15)
coinbase_tx2 = BCTransaction([], [6], [], [], -1, issuance=5)
shielding_tx = BCTransaction([coinbase_tx1.transparent_output(0), spend_tx.transparent_output(0)],
[], [], [8, 6], 1)
self.assertTrue(ctx.add_if_valid(coinbase_tx2))
self.assertTrue(ctx.add_if_valid(shielding_tx))
block2 = BCBlock(block1, 2, [coinbase_tx2, shielding_tx])
block2_anchor = ctx.copy()
self.assertEqual(block2.score, 4)
self.assertEqual(ctx.total_issuance, 20)
coinbase_tx3 = BCTransaction([], [7], [], [], -2, issuance=5)
shielded_tx = BCTransaction([], [], [shielding_tx.shielded_output(0)], [7], 1,
anchor=block2_anchor)
deshielding_tx = BCTransaction([], [5], [shielding_tx.shielded_output(1)], [], 1,
anchor=block2_anchor)
self.assertTrue(ctx.add_if_valid(coinbase_tx3))
self.assertTrue(ctx.add_if_valid(shielded_tx))
self.assertTrue(ctx.add_if_valid(deshielding_tx))
block3 = BCBlock(block2, 3, [coinbase_tx3, shielded_tx, deshielding_tx])
self.assertEqual(block3.score, 7)
self.assertEqual(ctx.total_issuance, 25)

View File

@ -7,168 +7,6 @@ modified in [Crosslink]). It might not be sufficient for other BFT
protocols but that's okay; it's a prototype.
[CS2020] https://eprint.iacr.org/2020/088.pdf
[Crosslink] https://hackmd.io/JqENg--qSmyqRt_RqY7Whw?view
"""
from __future__ import annotations
def two_thirds_threshold(n: int) -> int:
"""
Calculate the notarization threshold used in most permissioned BFT protocols:
`ceiling(n * 2/3)`.
"""
return (n * 2 + 2) // 3
class PermissionedBFTBase:
"""
This class is used for the genesis block in a permissioned BFT protocol
(which is taken to be notarized, and therefore valid, by definition).
It is also used as a base class for other BFT block and proposal classes.
"""
def __init__(self, n: int, t: int):
"""
Constructs a genesis block for a permissioned BFT protocol with
`n` nodes, of which at least `t` must sign each proposal.
"""
self.n = n
self.t = t
self.parent = None
def last_final(self) -> PermissionedBFTBase:
"""
Returns the last final block in this block's ancestor chain.
For the genesis block, this is itself.
"""
return self
class PermissionedBFTBlock(PermissionedBFTBase):
"""
A block for a BFT protocol. Each non-genesis block is based on a
notarized proposal, and in practice consists of the proposer's signature
over the notarized proposal.
Honest proposers must only ever sign at most one valid proposal for the
given epoch in which they are a proposer.
BFT blocks are taken to be notarized, and therefore valid, by definition.
"""
def __init__(self, proposal: PermissionedBFTProposal):
"""Constructs a `PermissionedBFTBlock` for the given proposal."""
super().__init__(proposal.n, proposal.t)
proposal.assert_notarized()
self.proposal = proposal
self.parent = proposal.parent
def last_final(self):
"""
Returns the last final block in this block's ancestor chain.
This should be overridden by subclasses; the default implementation
will (inefficiently) just return the genesis block.
"""
return self if self.parent is None else self.parent.last_final()
class PermissionedBFTProposal(PermissionedBFTBase):
"""A proposal for a BFT protocol."""
def __init__(self, parent: PermissionedBFTBase):
"""
Constructs a `PermissionedBFTProposal` with the given parent
`PermissionedBFTBlock`. The parameters are determined by the parent
block.
"""
super().__init__(parent.n, parent.t)
self.parent = parent
self.signers = set()
def assert_valid(self) -> None:
"""
Assert that this proposal is valid. This does not assert that it is
notarized. This should be overridden by subclasses.
"""
pass
def is_valid(self) -> bool:
"""Is this proposal valid?"""
try:
self.assert_valid()
return True
except AssertionError:
return False
def assert_notarized(self) -> None:
"""
Assert that this proposal is notarized. A `PermissionedBFTProposal`
is notarized iff it is valid and has at least the threshold number of
signatures.
"""
self.assert_valid()
assert len(self.signers) >= self.t
def is_notarized(self) -> bool:
"""Is this proposal notarized?"""
try:
self.assert_notarized()
return True
except AssertionError:
return False
def add_signature(self, index: int) -> None:
"""
Record that the node with the given `index` has signed this proposal.
If the same node signs more than once, the subsequent signatures are
ignored.
"""
self.signers.add(index)
assert len(self.signers) <= self.n
__all__ = ['two_thirds_threshold', 'PermissionedBFTBase', 'PermissionedBFTBlock', 'PermissionedBFTProposal']
import unittest
class TestPermissionedBFT(unittest.TestCase):
def test_basic(self) -> None:
# Construct the genesis block.
genesis = PermissionedBFTBase(5, 2)
current = genesis
self.assertEqual(current.last_final(), genesis)
for _ in range(2):
proposal = PermissionedBFTProposal(current)
proposal.assert_valid()
self.assertTrue(proposal.is_valid())
self.assertFalse(proposal.is_notarized())
# not enough signatures
proposal.add_signature(0)
self.assertFalse(proposal.is_notarized())
# same index, so we still only have one signature
proposal.add_signature(0)
self.assertFalse(proposal.is_notarized())
# different index, now we have two signatures as required
proposal.add_signature(1)
proposal.assert_notarized()
self.assertTrue(proposal.is_notarized())
current = PermissionedBFTBlock(proposal)
self.assertEqual(current.last_final(), genesis)
def test_assertions(self) -> None:
genesis = PermissionedBFTBase(5, 2)
proposal = PermissionedBFTProposal(genesis)
self.assertRaises(AssertionError, PermissionedBFTBlock, proposal)
proposal.add_signature(0)
self.assertRaises(AssertionError, PermissionedBFTBlock, proposal)
proposal.add_signature(1)
_ = PermissionedBFTBlock(proposal)

209
simtfl/bft/chain.py Normal file
View File

@ -0,0 +1,209 @@
"""
Abstractions for Byzantine Fault-Tolerant protocols.
"""
from __future__ import annotations
def two_thirds_threshold(n: int) -> int:
"""
Calculate the notarization threshold used in most permissioned BFT protocols:
`ceiling(n * 2/3)`.
"""
return (n * 2 + 2) // 3
class PermissionedBFTBase:
"""
This class is used for the genesis block in a permissioned BFT protocol
(which is taken to be notarized, and therefore valid, by definition).
It is also used as a base class for other BFT block and proposal classes.
"""
def __init__(self, n: int, t: int):
"""
Constructs a genesis block for a permissioned BFT protocol with
`n` nodes, of which at least `t` must sign each proposal.
"""
self.n = n
"""The number of voters."""
self.t = t
"""The threshold of votes required for notarization."""
self.parent = None
"""The genesis block has no parent (represented as `None`)."""
self.length = 1
"""The genesis chain length is 1."""
self.last_final = self
"""The last final block for the genesis block is itself."""
def preceq(self, other: PermissionedBFTBase):
"""Return True if this block is an ancestor of `other`."""
if self.length > other.length:
return False # optimization
return self == other or (other.parent is not None and self.preceq(other.parent))
def __eq__(self, other) -> bool:
return other.parent is None and (self.n, self.t) == (other.n, other.t)
def __hash__(self) -> int:
return hash((self.n, self.t))
class PermissionedBFTBlock(PermissionedBFTBase):
"""
A block for a BFT protocol. Each non-genesis block is based on a
notarized proposal, and in practice consists of the proposer's signature
over the notarized proposal.
Honest proposers must only ever sign at most one valid proposal for the
given epoch in which they are a proposer.
BFT blocks are taken to be notarized, and therefore valid, by definition.
"""
def __init__(self, proposal: PermissionedBFTProposal):
"""Constructs a `PermissionedBFTBlock` for the given proposal."""
super().__init__(proposal.n, proposal.t)
proposal.assert_notarized()
self.proposal = proposal
"""The proposal for this block."""
assert proposal.parent is not None
self.parent = proposal.parent
"""The parent of this block."""
self.length = proposal.length
"""The chain length of this block."""
self.last_final = self.parent.last_final
"""The last final block for this block."""
def __eq__(self, other) -> bool:
return (isinstance(other, PermissionedBFTBlock) and
(self.n, self.t, self.proposal) == (other.n, other.t, other.proposal))
def __hash__(self) -> int:
return hash((self.n, self.t, self.proposal))
class PermissionedBFTProposal(PermissionedBFTBase):
"""A proposal for a BFT protocol."""
def __init__(self, parent: PermissionedBFTBase):
"""
Constructs a `PermissionedBFTProposal` with the given parent
`PermissionedBFTBlock`. The parameters are determined by the parent
block.
"""
super().__init__(parent.n, parent.t)
self.parent = parent
"""The parent block of this proposal."""
self.length = parent.length + 1
"""The chain length of this proposal is one greater than its parent block."""
self.votes = set()
"""The set of voter indices that have voted for this proposal."""
def __eq__(self, other):
"""Two proposals are equal iff they are the same object."""
return self is other
def __hash__(self) -> int:
return id(self)
def assert_valid(self) -> None:
"""
Assert that this proposal is valid. This does not assert that it is
notarized. This should be overridden by subclasses.
"""
pass
def is_valid(self) -> bool:
"""Is this proposal valid?"""
try:
self.assert_valid()
return True
except AssertionError:
return False
def assert_notarized(self) -> None:
"""
Assert that this proposal is notarized. A `PermissionedBFTProposal`
is notarized iff it is valid and has at least the threshold number of
signatures.
"""
self.assert_valid()
assert len(self.votes) >= self.t
def is_notarized(self) -> bool:
"""Is this proposal notarized?"""
try:
self.assert_notarized()
return True
except AssertionError:
return False
def add_vote(self, index: int) -> None:
"""
Record that the node with the given `index` has voted for this proposal.
Calls that add the same vote more than once are ignored.
"""
self.votes.add(index)
assert len(self.votes) <= self.n
__all__ = ['two_thirds_threshold', 'PermissionedBFTBase', 'PermissionedBFTBlock', 'PermissionedBFTProposal']
import unittest
class TestPermissionedBFT(unittest.TestCase):
def test_basic(self) -> None:
# Construct the genesis block.
genesis = PermissionedBFTBase(5, 2)
current = genesis
self.assertEqual(current.last_final, genesis)
for _ in range(2):
parent = current
proposal = PermissionedBFTProposal(parent)
proposal.assert_valid()
self.assertTrue(proposal.is_valid())
self.assertFalse(proposal.is_notarized())
# not enough votes
proposal.add_vote(0)
self.assertFalse(proposal.is_notarized())
# same index, so we still only have one vote
proposal.add_vote(0)
self.assertFalse(proposal.is_notarized())
# different index, now we have two votes as required
proposal.add_vote(1)
proposal.assert_notarized()
self.assertTrue(proposal.is_notarized())
current = PermissionedBFTBlock(proposal)
self.assertTrue(parent.preceq(current))
self.assertFalse(current.preceq(parent))
self.assertNotEqual(current, parent)
self.assertEqual(current.last_final, genesis)
def test_assertions(self) -> None:
genesis = PermissionedBFTBase(5, 2)
proposal = PermissionedBFTProposal(genesis)
self.assertRaises(AssertionError, PermissionedBFTBlock, proposal)
proposal.add_vote(0)
self.assertRaises(AssertionError, PermissionedBFTBlock, proposal)
proposal.add_vote(1)
_ = PermissionedBFTBlock(proposal)

View File

@ -2,179 +2,6 @@
An implementation of adapted-Streamlet ([CS2020] as modified in [Crosslink]).
[CS2020] https://eprint.iacr.org/2020/088.pdf
[Crosslink] https://hackmd.io/JqENg--qSmyqRt_RqY7Whw?view
"""
from __future__ import annotations
from typing import Optional
from collections.abc import Sequence
from .. import PermissionedBFTBase, PermissionedBFTBlock, PermissionedBFTProposal, \
two_thirds_threshold
class StreamletProposal(PermissionedBFTProposal):
"""An adapted-Streamlet proposal."""
def __init__(self, parent: StreamletBlock | StreamletGenesis, epoch: int):
"""
Constructs a `StreamletProposal` with the given parent `StreamletBlock`,
for the given `epoch`. The parameters are determined by the parent block.
A proposal must be for an epoch after its parent's epoch.
"""
super().__init__(parent)
assert epoch > parent.epoch
self.epoch = epoch
"""The epoch of this proposal."""
def __repr__(self) -> str:
return "StreamletProposal(parent=%r, epoch=%r)" % (self.parent, self.epoch)
class StreamletGenesis(PermissionedBFTBase):
"""An adapted-Streamlet genesis block."""
def __init__(self, n: int):
"""
Constructs a genesis block for adapted-Streamlet with `n` nodes.
"""
super().__init__(n, two_thirds_threshold(n))
self.epoch = 0
"""The genesis block has epoch 0."""
def __repr__(self) -> str:
return "StreamletGenesis(n=%r)" % (self.n,)
class StreamletBlock(PermissionedBFTBlock):
"""
An adapted-Streamlet block. Each non-genesis Streamlet block is
based on a notarized `StreamletProposal`.
`StreamletBlock`s are taken to be notarized by definition.
All validity conditions are enforced in the contructor.
"""
def __init__(self, proposal: StreamletProposal):
"""Constructs a `StreamletBlock` for the given proposal."""
super().__init__(proposal)
self.epoch = proposal.epoch
def last_final(self) -> StreamletBlock | StreamletGenesis:
"""
Returns the last final block in this block's ancestor chain.
In Streamlet this is the middle block of the last group of three
that were proposed in consecutive epochs.
"""
last = self
if last.parent is None:
return last
middle = last.parent
if middle.parent is None:
return middle
first = middle.parent
while True:
if first.parent is None:
return first
if (first.epoch + 1, middle.epoch + 1) == (middle.epoch, last.epoch):
return middle
(first, middle, last) = (first.parent, first, middle)
def __repr__(self) -> str:
return "StreamletBlock(proposal=%r)" % (self.proposal,)
import unittest
from itertools import count
class TestStreamlet(unittest.TestCase):
def test_simple(self) -> None:
"""
Very simple example.
0 --- 1 --- 2 --- 3
"""
self._test_last_final([0, 1, 2], [0, 0, 2])
def test_figure_1(self) -> None:
"""
Figure 1: Streamlet finalization example (without the invalid 'X' proposal).
0 --- 2 --- 5 --- 6 --- 7
\
-- 1 --- 3
0 - Genesis
N - Notarized block
This diagram implies the epoch 6 block is the last-final block in the
context of the epoch 7 block, because it is in the middle of 3 blocks
with consecutive epoch numbers, and 6 is the most recent such block.
(We don't include the block/proposal with the red X because that's not
what we're testing.)
"""
self._test_last_final([0, 0, 1, None, 2, 5, 6], [0, 0, 0, 0, 0, 0, 6])
def test_complex(self) -> None:
"""
Safety Violation: due to three simultaneous properties:
- 6 is `last_final` in the context of 7
- 9 is `last_final` in the context of 10
- 9 is not a descendant of 6
0 --- 2 --- 5 --- 6 --- 7
\
-- 1 --- 3 --- 8 --- 9 --- 10
"""
self._test_last_final([0, 0, 1, None, 2, 5, 6, 3, 8, 9], [0, 0, 0, 0, 0, 0, 6, 0, 0, 9])
def _test_last_final(self, parent_map: Sequence[Optional[int]], final_map: Sequence[int]) -> None:
"""
This test constructs a tree of proposals with structure determined by
`parent_map`, and asserts `block.last_final()` matches the structure
determined by `final_map`.
parent_map: sequence of parent epoch numbers
final_map: sequence of final epoch numbers
"""
assert len(parent_map) == len(final_map)
# Construct the genesis block.
genesis = StreamletGenesis(3)
current = genesis
self.assertEqual(current.last_final(), genesis)
blocks = [genesis]
for (epoch, parent_epoch, final_epoch) in zip(count(1), parent_map, final_map):
if parent_epoch is None:
blocks.append(None)
continue
parent = blocks[parent_epoch]
assert parent is not None
proposal = StreamletProposal(parent, epoch)
proposal.assert_valid()
self.assertTrue(proposal.is_valid())
self.assertFalse(proposal.is_notarized())
# not enough signatures
proposal.add_signature(0)
self.assertFalse(proposal.is_notarized())
# same index, so we still only have one signature
proposal.add_signature(0)
self.assertFalse(proposal.is_notarized())
# different index, now we have two signatures as required
proposal.add_signature(1)
proposal.assert_notarized()
self.assertTrue(proposal.is_notarized())
current = StreamletBlock(proposal)
blocks.append(current)
self.assertEqual(current.last_final(), blocks[final_epoch])

View File

@ -0,0 +1,100 @@
"""
Adapted-Streamlet chain classes.
"""
from __future__ import annotations
from typing import Optional
from ..chain import PermissionedBFTBase, PermissionedBFTBlock, PermissionedBFTProposal, \
two_thirds_threshold
class StreamletProposal(PermissionedBFTProposal):
"""An adapted-Streamlet proposal."""
def __init__(self, parent: StreamletBlock | StreamletGenesis, epoch: int):
"""
Constructs a `StreamletProposal` with the given parent `StreamletBlock`,
for the given `epoch`. The parameters are determined by the parent block.
A proposal must be for an epoch after its parent's epoch.
"""
super().__init__(parent)
self.parent: StreamletBlock | StreamletGenesis = parent
assert epoch > parent.epoch
self.epoch = epoch
"""The epoch of this proposal."""
def __str__(self) -> str:
return f"StreamletProposal(parent={self.parent}, epoch={self.epoch}, length={self.length})"
class StreamletGenesis(PermissionedBFTBase):
"""An adapted-Streamlet genesis block."""
def __init__(self, n: int):
"""
Constructs a genesis block for adapted-Streamlet with `n` nodes.
"""
super().__init__(n, two_thirds_threshold(n))
self.parent: Optional[StreamletBlock | StreamletGenesis] = None
"""The genesis block has no parent (represented as `None`)."""
self.epoch = 0
"""The epoch of the genesis block is 0."""
self.last_final = self
"""The last final block of the genesis block is itself."""
def __str__(self) -> str:
return f"StreamletGenesis(n={self.n})"
def proposer_for_epoch(self, epoch: int):
assert epoch > 0
return (epoch - 1) % self.n
class StreamletBlock(PermissionedBFTBlock):
"""
An adapted-Streamlet block. Each non-genesis Streamlet block is
based on a notarized `StreamletProposal`.
`StreamletBlock`s are taken to be notarized by definition.
All validity conditions are enforced in the contructor.
"""
def __init__(self, proposal: StreamletProposal):
"""Constructs a `StreamletBlock` for the given proposal."""
super().__init__(proposal)
self.epoch = proposal.epoch
"""The epoch of this proposal."""
self.parent: StreamletBlock | StreamletGenesis = proposal.parent
self.last_final = self._compute_last_final()
"""
The last final block in this block's ancestor chain.
In Streamlet this is the middle block of the last group of three
that were proposed in consecutive epochs.
"""
def _compute_last_final(self) -> StreamletBlock | StreamletGenesis:
last: StreamletBlock | StreamletGenesis = self
if last.parent is None:
return last
middle: StreamletBlock | StreamletGenesis = last.parent
if middle.parent is None:
return middle
first: StreamletBlock | StreamletGenesis = middle.parent
while True:
if first.parent is None:
return first
if (first.epoch + 1, middle.epoch + 1) == (middle.epoch, last.epoch):
return middle
(first, middle, last) = (first.parent, first, middle)
def __str__(self) -> str:
return f"StreamletBlock(proposal={self.proposal})"

View File

@ -4,12 +4,15 @@ An adapted-Streamlet node.
from __future__ import annotations
from typing import Optional
from collections.abc import Sequence
from dataclasses import dataclass
from ...node import SequentialNode
from ...message import Message, PayloadMessage
from ...util import skip, ProcessEffect
from . import StreamletGenesis, StreamletBlock, StreamletProposal
from .chain import StreamletGenesis, StreamletBlock, StreamletProposal
class Echo(PayloadMessage):
@ -20,6 +23,35 @@ class Echo(PayloadMessage):
pass
@dataclass(frozen=True)
class Ballot(Message):
"""
A ballot message, recording that a voter has voted for a `StreamletProposal`.
Ballots should not be forged unless modelling an attack that allows doing so.
"""
proposal: StreamletProposal
"""The proposal."""
voter: int
"""The voter."""
def __str__(self) -> str:
return f"Ballot({self.proposal}, voter={self.voter})"
class Proposal(PayloadMessage):
"""
A message containing a `StreamletProposal`.
"""
pass
class Block(PayloadMessage):
"""
A message containing a `StreamletBlock`.
"""
pass
class StreamletNode(SequentialNode):
"""
A Streamlet node.
@ -32,7 +64,32 @@ class StreamletNode(SequentialNode):
"""
assert genesis.epoch == 0
self.genesis = genesis
"""The genesis block."""
self.voted_epoch = genesis.epoch
"""The last epoch on which this node voted."""
self.tip: StreamletBlock | StreamletGenesis = genesis
"""
A longest chain seen by this node. The node's last final block is given by
`self.tip.last_final`.
"""
self.proposal: Optional[StreamletProposal] = None
"""The current proposal by this node, when it is the proposer."""
self.safety_violations: set[tuple[StreamletBlock | StreamletGenesis,
StreamletBlock | StreamletGenesis]] = set()
"""The set of safety violations detected by this node."""
def propose(self, proposal: StreamletProposal) -> ProcessEffect:
"""
(process) Ask the node to make a proposal.
"""
assert proposal.is_valid()
assert proposal.epoch > self.voted_epoch
self.proposal = proposal
return self.broadcast(Proposal(proposal), False)
def handle(self, sender: int, message: Message) -> ProcessEffect:
"""
@ -42,32 +99,242 @@ class StreamletNode(SequentialNode):
(This causes the number of messages to blow up by a factor of `n`,
but it's what the Streamlet paper specifies and is necessary for
its liveness proof.)
* Received non-duplicate proposals may cause us to send a `Vote`.
* ...
* Receiving a non-duplicate `Proposal` may cause us to broadcast a `Ballot`.
* If we are the current proposer, keep track of ballots for our proposal.
* Receiving a `Block` may cause us to update our `tip`.
"""
if isinstance(message, Echo):
message = message.payload
else:
yield from self.broadcast(Echo(message))
yield from self.broadcast(Echo(message), False)
if isinstance(message, StreamletProposal):
yield from self.handle_proposal(message)
elif isinstance(message, StreamletBlock):
yield from self.handle_block(message)
if isinstance(message, Proposal):
yield from self.handle_proposal(message.payload)
elif isinstance(message, Block):
yield from self.handle_block(message.payload)
elif isinstance(message, Ballot):
yield from self.handle_ballot(message)
else:
yield from super().handle(sender, message)
def handle_proposal(self, proposal: StreamletProposal) -> ProcessEffect:
"""
(process) If we already voted in the epoch specified by the proposal or a
later epoch, ignore this proposal.
later epoch, ignore this proposal. Otherwise, cast a vote for it iff it
is valid.
"""
if proposal.epoch <= self.voted_epoch:
self.log("handle",
self.log("proposal",
f"received proposal for epoch {proposal.epoch} but we already voted in epoch {self.voted_epoch}")
return skip()
return skip()
if proposal.is_valid():
self.log("proposal", f"voting for {proposal}")
# For now we just forget that we made a proposal if we receive a different
# valid one from another node. This is not realistic. Note that we can and
# should vote for our own proposal.
if proposal != self.proposal:
self.proposal = None
self.voted_epoch = proposal.epoch
return self.broadcast(Ballot(proposal, self.ident), True)
else:
return skip()
def handle_block(self, block: StreamletBlock) -> ProcessEffect:
raise NotImplementedError
"""
If `block.last_final` does not descend from `self.tip.last_final`, reject the block.
(In this case, if also `self.tip.last_final` does not descend from `block.last_final`,
this is a detected safety violation.)
Otherwise, update `self.tip` to `block` iff `block` is later in lexicographic ordering
by `(length, epoch)`.
"""
if not self.tip.last_final.preceq(block.last_final):
self.log("block", f"× not ⪰ last_final: {block}")
if not block.last_final.preceq(self.tip.last_final):
self.log("block", f"! safety violation: ({block}, {self.tip})")
self.safety_violations.add((block, self.tip))
return skip()
# TODO: analyse tie-breaking rule.
if (self.tip.length, self.tip.epoch) >= (block.length, block.epoch):
self.log("block", f"× not updating tip: {block}")
return skip()
self.log("block", f"✓ updating tip: {block}")
self.tip = block
return skip()
def handle_ballot(self, ballot: Ballot) -> ProcessEffect:
"""
If we have made a proposal that is not yet notarized and the ballot is
for that proposal, add the vote. If it is now notarized, broadcast it
as a block.
"""
proposal = ballot.proposal
if proposal == self.proposal:
self.log("count", f"{ballot.voter} voted for our proposal in epoch {proposal.epoch}")
proposal.add_vote(ballot.voter)
if proposal.is_notarized():
yield from self.broadcast(Block(StreamletBlock(proposal)), True)
# It's fine to forget that we made the proposal now.
self.proposal = None
def final_block(self) -> StreamletBlock | StreamletGenesis:
"""
Return the last final block seen by this node.
"""
return self.tip.last_final
__all__ = ['Echo', 'Ballot', 'StreamletNode']
import unittest
from itertools import count
from simpy import Environment
from simpy.events import Process, Timeout
from ...network import Network
from ...logging import PrintLogger
class TestStreamlet(unittest.TestCase):
def test_simple(self) -> None:
"""
Very simple example.
0 --- 1 --- 2 --- 3
"""
self._test_last_final([0, 1, 2],
[0, 0, 2])
def test_figure_1(self) -> None:
"""
Figure 1: Streamlet finalization example (without the invalid 'X' proposal).
0 --- 2 --- 5 --- 6 --- 7
\
-- 1 --- 3
0 - Genesis
N - Notarized block
This diagram implies the epoch 6 block is the last-final block in the
context of the epoch 7 block, because it is in the middle of 3 blocks
with consecutive epoch numbers, and 6 is the most recent such block.
(We don't include the block/proposal with the red X because that's not
what we're testing.)
"""
N = None
self._test_last_final([0, 0, 1, N, 2, 5, 6],
[0, 0, 0, 0, 0, 0, 6])
def test_complex(self) -> None:
"""
Safety Violation: due to three simultaneous properties:
- 6 is `last_final` in the context of 7
- 9 is `last_final` in the context of 10
- 9 is not a descendant of 6
0 --- 2 --- 5 --- 6 --- 7
\
-- 1 --- 3 --- 8 --- 9 --- 10
"""
N = None
self._test_last_final([0, 0, 1, N, 2, 5, 6, 3, 8, 9],
[0, 0, 0, 0, 0, 0, 6, 0, 0, 9],
expect_divergence_at_epoch=8,
expect_safety_violations={(10, 7)})
def _test_last_final(self,
parent_map: Sequence[Optional[int]],
final_map: Sequence[int],
expect_divergence_at_epoch: Optional[int]=None,
expect_safety_violations: set[tuple[int, int]]=set()) -> None:
"""
This test constructs a tree of proposals with structure determined by
`parent_map`, and asserts `block.last_final` matches the structure
determined by `final_map`.
parent_map: sequence of parent epoch numbers
final_map: sequence of final epoch numbers
expect_divergence_at_epoch: first epoch at which a block does not become the new tip
expect_safety_violations: safety violation proofs
"""
assert len(parent_map) == len(final_map)
# Construct the genesis block.
genesis = StreamletGenesis(3)
network = Network(Environment(), logger=PrintLogger())
for _ in range(genesis.n):
network.add_node(StreamletNode(genesis))
current = genesis
self.assertEqual(current.last_final, genesis)
blocks: list[Optional[StreamletBlock | StreamletGenesis]] = [genesis]
def run() -> ProcessEffect:
for (epoch, parent_epoch, final_epoch) in zip(count(1), parent_map, final_map):
yield Timeout(network.env, 10)
if parent_epoch is None:
blocks.append(None)
continue
parent = blocks[parent_epoch]
assert parent is not None
proposer = network.node(genesis.proposer_for_epoch(epoch))
proposal = StreamletProposal(parent, epoch)
self.assertEqual(proposal.length, parent.length + 1)
proposal.assert_valid()
self.assertFalse(proposal.is_notarized())
proposer.propose(proposal)
yield Timeout(network.env, 10)
# The proposer should have sent the block.
assert proposer.proposal is None
# Make a fake block `current` from the proposal so that we can append
# it to `blocks` and check its `last_final`.
current = StreamletBlock(proposal)
self.assertEqual(current.length, proposal.length)
self.assertTrue(parent.preceq(current))
self.assertFalse(current.preceq(parent))
self.assertEqual(len(blocks), current.epoch)
blocks.append(current)
final_block = blocks[final_epoch]
assert final_block is not None
self.assertEqual(current.last_final, final_block)
# All nodes' tips should be the same.
tip = network.node(0).tip
for i in range(1, network.num_nodes()):
self.assertEqual(network.node(i).tip, tip)
# If we try to create a new block on top of a chain that is not the longest,
# the nodes will ignore it.
if epoch == expect_divergence_at_epoch:
self.assertLess(current.length, tip.length)
elif expect_divergence_at_epoch is None or epoch < expect_divergence_at_epoch:
self.assertEqual(current.length, tip.length)
self.assertEqual(tip.epoch, epoch)
self.assertEqual(tip.proposal, proposal)
for node in network.nodes:
node_final = node.final_block()
self.assertEqual(node_final, final_block,
f"epoch {node_final.epoch} != epoch {final_block.epoch}")
for node in network.nodes:
self.assertEqual(set(((a.epoch, b.epoch) for (a, b) in node.safety_violations)),
expect_safety_violations)
network.done = True
Process(network.env, run())
network.run_all()
self.assertTrue(network.done)

View File

@ -22,3 +22,6 @@ class PayloadMessage(Message):
"""
payload: Any
"""The payload."""
def __str__(self) -> str:
return f"{self.__class__.__name__}({self.payload})"

View File

@ -30,7 +30,7 @@ class Node:
self.env = env
self.network = network
def __str__(self):
def __str__(self) -> str:
return f"{self.__class__.__name__}"
def log(self, event: str, detail: str):
@ -47,13 +47,13 @@ class Node:
"""
return self.network.send(self.ident, target, message, delay=delay)
def broadcast(self, message: Message, delay: Optional[Number]=None) -> ProcessEffect:
def broadcast(self, message: Message, include_self: bool, delay: Optional[Number]=None) -> ProcessEffect:
"""
(process) This method can be overridden to intercept messages being broadcast
by this node. The implementation in this class calls `self.network.broadcast`
with this node as the sender.
"""
return self.network.broadcast(self.ident, message, delay=delay)
return self.network.broadcast(self.ident, message, include_self, delay=delay)
def receive(self, sender: int, message: Message) -> ProcessEffect:
"""
@ -86,8 +86,14 @@ class Network:
a set of initial nodes, message propagation delay, and logger.
"""
self.env = env
"""The `simpy.Environment`."""
self.nodes = nodes or []
"""The nodes in this network."""
self.delay = delay
"""The message propagation delay."""
self._logger = logger
logger.header()
@ -166,19 +172,21 @@ class Network:
# TODO: make it take some time on the sending node.
return skip()
def broadcast(self, sender: int, message: Message, delay: Optional[Number]=None) -> ProcessEffect:
def broadcast(self, sender: int, message: Message, include_self: bool,
delay: Optional[Number]=None) -> ProcessEffect:
"""
(process) Broadcasts a message to every other node. The message
propagation delay is normally given by `self.delay`, but can be
overridden by the `delay` parameter.
(process) Broadcasts a message to every node (including ourself only when
`include_self` is set). The message propagation delay is normally given by
`self.delay`, but can be overridden by the `delay` parameter.
"""
if delay is None:
delay = self.delay
self.log(sender, "broadcast", f"to * with delay {delay:2d}: {message}")
c = "+" if include_self else "-"
self.log(sender, "broadcast", f"to {c}* with delay {delay:2d}: {message}")
# Run `convey` in a new process for each node.
for target in range(self.num_nodes()):
if target != sender:
if include_self or target != sender:
Process(self.env, self.convey(delay, sender, target, message))
# Broadcasting is currently instantaneous.

View File

@ -89,7 +89,7 @@ class SequentialNode(Node):
while True:
while len(self._mailbox) > 0:
(sender, message) = self._mailbox.popleft()
self.log("handle", f"from {sender:2d}: {message}")
self.log("handle", f"from {sender:2d}: {message}")
yield from self.handle(sender, message)
# This naive implementation is fine because we have no actual
@ -147,7 +147,7 @@ class SenderTestNode(PassiveNode):
yield Timeout(self.env, 1)
# This message is broadcast at time 4 and received at time 5.
yield from self.broadcast(PayloadMessage(4))
yield from self.broadcast(PayloadMessage(4), False)
class TestFramework(unittest.TestCase):