Wallet tests

This commit is contained in:
Mariano Sorgente 2020-04-25 20:13:57 +09:00
parent 2705a4e75c
commit 36caabb372
No known key found for this signature in database
GPG Key ID: 0F866338C369278C
6 changed files with 94 additions and 37 deletions

View File

@ -87,6 +87,7 @@ def generate(args, parser):
key_config["wallet_sk"] = bytes(wallet_sk).hex()
key_config["wallet_target"] = wallet_target.hex()
save_config(root_path, keys_yaml, key_config)
print("WALLET TARGET:", wallet_target.hex())
if args.harvester:
# Replaces the harvester's sk seed. Used to generate plot private keys, which are
# used to sign farmed blocks.
@ -109,4 +110,6 @@ def generate(args, parser):
# Compute a new pool target and save it to the config
assert "wallet_target" in key_config
key_config["pool_target"] = key_config["wallet_target"]
print("POOL TARGET:", key_config["pool_target"])
print("Root path", root_path)
save_config(root_path, keys_yaml, key_config)

View File

@ -222,6 +222,7 @@ class FullNode:
yield OutboundMessage(NodeType.INTRODUCER, msg, Delivery.RESPOND)
while not self._shut_down:
# If we are still connected to introducer, disconnect
for connection in self.server.global_connections.get_connections():
if connection.connection_type == NodeType.INTRODUCER:
self.server.global_connections.close(connection)

View File

@ -167,6 +167,10 @@ class WalletNode:
yield OutboundMessage(NodeType.INTRODUCER, msg, Delivery.RESPOND)
while not self._shut_down:
for connection in self.server.global_connections.get_connections():
# If we are still connected to introducer, disconnect
if connection.connection_type == NodeType.INTRODUCER:
self.server.global_connections.close(connection)
if self._num_needed_peers():
if not await self.server.start_client(
introducer_peerinfo, on_connect

View File

@ -1146,7 +1146,7 @@ class WalletStateManager:
# Get all unspent coins
my_coin_records_lca: Set[
WalletCoinRecord
] = await self.wallet_store.get_coin_records_by_spent(False, uint32(fork_h + 1))
] = await self.wallet_store.get_unspent_coins_at_height(uint32(fork_h))
# Filter coins up to and including fork point
unspent_coin_names: Set[bytes32] = set()
@ -1267,9 +1267,6 @@ class WalletStateManager:
async def get_all_wallets(self) -> List[WalletInfo]:
return await self.user_store.get_all_wallets()
async def get_coin_records_by_spent(self, spent: bool):
return await self.wallet_store.get_coin_records_by_spent(spent)
async def get_spendable_coins_for_wallet(
self, wallet_id: int
) -> Set[WalletCoinRecord]:

View File

@ -143,19 +143,22 @@ class WalletStore:
)
return None
async def get_coin_records_by_spent(
self, spent: bool, spend_before_height: Optional[uint32] = None
async def get_unspent_coins_at_height(
self, height: Optional[uint32] = None
) -> Set[WalletCoinRecord]:
""" Returns set of CoinRecords that have not been spent yet. """
coins = set()
if spend_before_height:
if height is not None:
cursor_test = await self.db_connection.execute(
"SELECT * from coin_record",
)
cursor = await self.db_connection.execute(
"SELECT * from coin_record WHERE spent=? OR spent_index>=?",
(int(spent), spend_before_height),
"SELECT * from coin_record WHERE (spent=? OR spent_index>?) AND confirmed_index<=?",
(0, height, height),
)
else:
cursor = await self.db_connection.execute(
"SELECT * from coin_record WHERE spent=?", (int(spent),)
"SELECT * from coin_record WHERE spent=?", (0,)
)
rows = await cursor.fetchall()
await cursor.close()
@ -243,7 +246,7 @@ class WalletStore:
result: Dict[bytes32, Coin] = {}
unspent_coin_records: Set[
WalletCoinRecord
] = await self.get_coin_records_by_spent(False)
] = await self.get_unspent_coins_at_height()
for record in unspent_coin_records:
result[record.name()] = record.coin

View File

@ -1,16 +1,21 @@
# import asyncio
# from secrets import token_bytes
# from pathlib import Path
# from typing import Any, Dict
# import sqlite3
# import random
import asyncio
from secrets import token_bytes
from pathlib import Path
from typing import Any, Dict
from secrets import token_bytes
import aiosqlite
import random
# import pytest
# from src.full_node.store import FullNodeStore
# from src.types.full_block import FullBlock
# from src.types.sized_bytes import bytes32
# from src.util.ints import uint32, uint64
# from tests.block_tools import BlockTools
import pytest
from src.full_node.store import FullNodeStore
from src.types.full_block import FullBlock
from src.types.sized_bytes import bytes32
from src.util.ints import uint32, uint64
from tests.block_tools import BlockTools
from src.wallet.wallet_store import WalletStore
from src.wallet.wallet_coin_record import WalletCoinRecord
from src.wallet.util.wallet_types import WalletType
from src.types.coin import Coin
# bt = BlockTools()
@ -27,21 +32,65 @@
# )
# @pytest.fixture(scope="module")
# def event_loop():
# loop = asyncio.get_event_loop()
# yield loop
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
# class TestWalletStore:
# @pytest.mark.asyncio
# async def test_store(self):
# blocks = bt.get_consecutive_blocks(test_constants, 9, [], 9, b"0")
# db_filename = Path("blockchain_wallet_store_test.db")
class TestWalletStore:
@pytest.mark.asyncio
async def test_store(self):
db_filename = Path("blockchain_wallet_store_test.db")
# if db_filename.exists():
# db_filename.unlink()
if db_filename.exists():
db_filename.unlink()
# db = await FullNodeStore.create(db_filename)
# try:
# await db._clear_database()
db_connection = await aiosqlite.connect(db_filename)
store = await WalletStore.create(db_connection)
try:
coin_1 = Coin(token_bytes(32), token_bytes(32), uint64(12312))
coin_2 = Coin(token_bytes(32), token_bytes(32), uint64(12312))
coin_3 = Coin(token_bytes(32), token_bytes(32), uint64(12312))
coin_4 = Coin(token_bytes(32), token_bytes(32), uint64(12312))
record_replaced = WalletCoinRecord(coin_1, uint32(8), uint32(0), False, True, WalletType.STANDARD_WALLET, 0)
record_1 = WalletCoinRecord(coin_1, uint32(4), uint32(0), False, True, WalletType.STANDARD_WALLET, 0)
record_2 = WalletCoinRecord(coin_2, uint32(5), uint32(0), False, True, WalletType.STANDARD_WALLET, 0)
record_3 = WalletCoinRecord(coin_3, uint32(5), uint32(10), True, False, WalletType.STANDARD_WALLET, 0)
record_4 = WalletCoinRecord(coin_4, uint32(5), uint32(15), True, False, WalletType.STANDARD_WALLET, 0)
# Test add (replace) and get
assert (await store.get_coin_record(coin_1.name()) is None)
await store.add_coin_record(record_replaced)
await store.add_coin_record(record_1)
await store.add_coin_record(record_2)
await store.add_coin_record(record_3)
await store.add_coin_record(record_4)
assert (await store.get_coin_record(coin_1.name()) == record_1)
# Test persistance
await db_connection.close()
db_connection = await aiosqlite.connect(db_filename)
store = await WalletStore.create(db_connection)
assert (await store.get_coin_record(coin_1.name()) == record_1)
# Test set spent
await store.set_spent(coin_1.name(), uint32(12))
assert (await store.get_coin_record(coin_1.name())).spent
assert ((await store.get_coin_record(coin_1.name())).spent_block_index == 12)
# No coins at height 3
assert len(await store.get_unspent_coins_at_height(3)) == 0
assert len(await store.get_unspent_coins_at_height(4)) == 1
assert len(await store.get_unspent_coins_at_height(5)) == 4
assert len(await store.get_unspent_coins_at_height(11)) == 3
assert len(await store.get_unspent_coins_at_height(12)) == 2
assert len(await store.get_unspent_coins_at_height(15)) == 1
assert len(await store.get_unspent_coins_at_height(16)) == 1
assert len(await store.get_unspent_coins_at_height()) == 1
except:
await db_connection.close()
raise
await db_connection.close()