multi account

This commit is contained in:
Hanh 2021-06-29 15:04:12 +08:00
parent f1d948c0f6
commit cb44cb2438
8 changed files with 263 additions and 193 deletions

View File

@ -1,7 +1,7 @@
use crate::commitment::{CTree, Witness};
use crate::lw_rpc::compact_tx_streamer_client::CompactTxStreamerClient;
use crate::lw_rpc::*;
use crate::{advance_tree, NETWORK};
use crate::{advance_tree, NETWORK, LWD_URL};
use ff::PrimeField;
use group::GroupEncoding;
use log::info;
@ -17,13 +17,9 @@ use zcash_primitives::sapling::note_encryption::try_sapling_compact_note_decrypt
use zcash_primitives::sapling::{Node, Note, PaymentAddress};
use zcash_primitives::transaction::components::sapling::CompactOutputDescription;
use zcash_primitives::zip32::ExtendedFullViewingKey;
use std::collections::HashMap;
const MAX_CHUNK: u32 = 50000;
pub const LWD_URL: &str = "https://mainnet.lightwalletd.com:9067";
// pub const LWD_URL: &str = "https://testnet.lightwalletd.com:9067";
// pub const LWD_URL: &str = "http://lwd.hanh.me:9067";
// pub const LWD_URL: &str = "https://lwdv3.zecwallet.co";
// pub const LWD_URL: &str = "http://127.0.0.1:9067";
pub async fn get_latest_height(
client: &mut CompactTxStreamerClient<Channel>,
@ -82,12 +78,18 @@ pub async fn download_chain(
}
pub struct DecryptNode {
fvks: Vec<ExtendedFullViewingKey>,
fvks: HashMap<u32, ExtendedFullViewingKey>,
}
#[derive(Eq, Hash, PartialEq, Copy, Clone)]
pub struct Nf(pub [u8; 32]);
#[derive(Copy, Clone)]
pub struct NfRef {
pub id_note: u32,
pub account: u32
}
pub struct DecryptedBlock<'a> {
pub height: u32,
pub notes: Vec<DecryptedNote>,
@ -98,6 +100,7 @@ pub struct DecryptedBlock<'a> {
#[derive(Clone)]
pub struct DecryptedNote {
pub account: u32,
pub ivk: ExtendedFullViewingKey,
pub note: Note,
pub pa: PaymentAddress,
@ -126,7 +129,7 @@ pub fn to_output_description(co: &CompactOutput) -> CompactOutputDescription {
fn decrypt_notes<'a>(
block: &'a CompactBlock,
fvks: &[ExtendedFullViewingKey],
fvks: &HashMap<u32, ExtendedFullViewingKey>,
) -> DecryptedBlock<'a> {
let height = BlockHeight::from_u32(block.height as u32);
let mut count_outputs = 0u32;
@ -140,13 +143,14 @@ fn decrypt_notes<'a>(
}
for (output_index, co) in vtx.outputs.iter().enumerate() {
for fvk in fvks.iter() {
for (&account, fvk) in fvks.iter() {
let ivk = &fvk.fvk.vk.ivk();
let od = to_output_description(co);
if let Some((note, pa)) =
try_sapling_compact_note_decryption(&NETWORK, height, ivk, &od)
{
notes.push(DecryptedNote {
account,
ivk: fvk.clone(),
note,
pa,
@ -171,7 +175,7 @@ fn decrypt_notes<'a>(
}
impl DecryptNode {
pub fn new(fvks: Vec<ExtendedFullViewingKey>) -> DecryptNode {
pub fn new(fvks: HashMap<u32, ExtendedFullViewingKey>) -> DecryptNode {
DecryptNode { fvks }
}
@ -324,12 +328,15 @@ pub async fn connect_lightwalletd() -> anyhow::Result<CompactTxStreamerClient<Ch
Ok(client)
}
pub async fn sync(ivk: &str) -> anyhow::Result<()> {
let fvk =
decode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &ivk)
.unwrap()
.unwrap();
let decrypter = DecryptNode::new(vec![fvk]);
pub async fn sync(fvks: &HashMap<u32, String>) -> anyhow::Result<()> {
let fvks: HashMap<_, _> = fvks.iter().map(|(&account, fvk)| {
let fvk =
decode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &fvk)
.unwrap()
.unwrap();
(account, fvk)
}).collect();
let decrypter = DecryptNode::new(fvks);
let mut client = connect_lightwalletd().await?;
let start_height: u32 = crate::NETWORK
.activation_height(NetworkUpgrade::Sapling)
@ -361,7 +368,7 @@ pub async fn sync(ivk: &str) -> anyhow::Result<()> {
#[cfg(test)]
mod tests {
use crate::chain::LWD_URL;
use crate::LWD_URL;
#[allow(unused_imports)]
use crate::chain::{
calculate_tree_state_v1, calculate_tree_state_v2, download_chain, get_latest_height,
@ -373,6 +380,8 @@ mod tests {
use std::time::Instant;
use zcash_client_backend::encoding::decode_extended_full_viewing_key;
use zcash_primitives::consensus::{NetworkUpgrade, Parameters};
use zcash_primitives::zip32::ExtendedFullViewingKey;
use std::collections::HashMap;
#[tokio::test]
async fn test_get_latest_height() -> anyhow::Result<()> {
@ -387,11 +396,13 @@ mod tests {
dotenv::dotenv().unwrap();
let fvk = dotenv::var("FVK").unwrap();
let mut fvks: HashMap<u32, ExtendedFullViewingKey> = HashMap::new();
let fvk =
decode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &fvk)
.unwrap()
.unwrap();
let decrypter = DecryptNode::new(vec![fvk]);
fvks.insert(1, fvk);
let decrypter = DecryptNode::new(fvks);
let mut client = CompactTxStreamerClient::connect(LWD_URL).await?;
let start_height: u32 = crate::NETWORK
.activation_height(NetworkUpgrade::Sapling)

150
src/db.rs
View File

@ -1,4 +1,4 @@
use crate::chain::Nf;
use crate::chain::{Nf, NfRef};
use crate::{CTree, Witness};
use rusqlite::{params, Connection, OptionalExtension, NO_PARAMS};
use std::collections::HashMap;
@ -15,6 +15,7 @@ pub struct DbAdapter {
}
pub struct ReceivedNote {
pub account: u32,
pub height: u32,
pub output_index: u32,
pub diversifier: Vec<u8>,
@ -41,9 +42,10 @@ impl DbAdapter {
self.connection.execute(
"CREATE TABLE IF NOT EXISTS accounts (
id_account INTEGER PRIMARY KEY,
seed TEXT NOT NULL,
sk TEXT NOT NULL UNIQUE,
ivk TEXT NOT NULL,
name TEXT NOT NULL,
seed TEXT,
sk TEXT,
ivk TEXT NOT NULL UNIQUE,
address TEXT NOT NULL)",
NO_PARAMS,
)?;
@ -60,6 +62,7 @@ impl DbAdapter {
self.connection.execute(
"CREATE TABLE IF NOT EXISTS transactions (
id_tx INTEGER PRIMARY KEY,
account INTEGER NOT NULL,
txid BLOB NOT NULL UNIQUE,
height INTEGER NOT NULL,
timestamp INTEGER NOT NULL,
@ -71,6 +74,7 @@ impl DbAdapter {
self.connection.execute(
"CREATE TABLE IF NOT EXISTS received_notes (
id_note INTEGER PRIMARY KEY,
account INTEGER NOT NULL,
position INTEGER NOT NULL,
tx INTEGER NOT NULL,
height INTEGER NOT NULL,
@ -97,25 +101,33 @@ impl DbAdapter {
Ok(())
}
pub fn store_account(&self, seed: &str, sk: &str, ivk: &str, address: &str) -> anyhow::Result<()> {
pub fn store_account(&self, name: &str, seed: Option<&str>, sk: Option<&str>, ivk: &str, address: &str) -> anyhow::Result<u32> {
self.connection.execute(
"INSERT INTO accounts(seed, sk, ivk, address) VALUES (?1, ?2, ?3, ?4)
"INSERT INTO accounts(name, seed, sk, ivk, address) VALUES (?1, ?2, ?3, ?4, ?5)
ON CONFLICT DO NOTHING",
params![seed, sk, ivk, address],
params![name, seed, sk, ivk, address],
)?;
Ok(())
let id_tx: u32 = self.connection.query_row(
"SELECT id_account FROM accounts WHERE sk = ?1",
params![sk],
|row| row.get(0),
)?;
Ok(id_tx)
}
pub fn has_account(&self, account: u32) -> anyhow::Result<bool> {
let r: Option<i32> = self
.connection
.query_row(
"SELECT 1 FROM accounts WHERE id_account = ?1",
params![account],
|row| row.get(0),
)
.optional()?;
Ok(r.is_some())
pub fn get_fvks(&self) -> anyhow::Result<HashMap<u32, String>> {
let mut statement = self.connection.prepare("SELECT id_account, ivk FROM accounts")?;
let rows = statement.query_map(NO_PARAMS, |row| {
let account: u32 = row.get(0)?;
let ivk: String = row.get(1)?;
Ok((account, ivk))
})?;
let mut fvks: HashMap<u32, String> = HashMap::new();
for r in rows {
let row = r?;
fvks.insert(row.0, row.1);
}
Ok(fvks)
}
pub fn trim_to_height(&mut self, height: u32) -> anyhow::Result<()> {
@ -161,16 +173,17 @@ impl DbAdapter {
pub fn store_transaction(
&self,
txid: &[u8],
account: u32,
height: u32,
timestamp: u32,
tx_index: u32,
) -> anyhow::Result<u32> {
log::debug!("+transaction");
self.connection.execute(
"INSERT INTO transactions(txid, height, timestamp, tx_index, value)
VALUES (?1, ?2, ?3, ?4, 0)
"INSERT INTO transactions(account, txid, height, timestamp, tx_index, value)
VALUES (?1, ?2, ?3, ?4, ?5, 0)
ON CONFLICT DO NOTHING",
params![txid, height, timestamp, tx_index],
params![account, txid, height, timestamp, tx_index],
)?;
let id_tx: u32 = self.connection.query_row(
"SELECT id_tx FROM transactions WHERE txid = ?1",
@ -188,9 +201,9 @@ impl DbAdapter {
position: usize,
) -> anyhow::Result<u32> {
log::debug!("+received_note {}", id_tx);
self.connection.execute("INSERT INTO received_notes(tx, height, position, output_index, diversifier, value, rcm, nf, spent)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)
ON CONFLICT DO NOTHING", params![id_tx, note.height, position as u32, note.output_index, note.diversifier, note.value as i64, note.rcm, note.nf, note.spent])?;
self.connection.execute("INSERT INTO received_notes(account, tx, height, position, output_index, diversifier, value, rcm, nf, spent)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
ON CONFLICT DO NOTHING", params![note.account, id_tx, note.height, position as u32, note.output_index, note.diversifier, note.value as i64, note.rcm, note.nf, note.spent])?;
let id_note: u32 = self.connection.query_row(
"SELECT id_note FROM received_notes WHERE tx = ?1 AND output_index = ?2",
params![id_tx, note.output_index],
@ -226,28 +239,32 @@ impl DbAdapter {
Ok(())
}
pub fn get_received_note_value(&self, nf: &Nf) -> anyhow::Result<i64> {
let value: i64 = self.connection.query_row(
"SELECT value FROM received_notes WHERE nf = ?1",
pub fn get_received_note_value(&self, nf: &Nf) -> anyhow::Result<(u32, i64)> {
let (account, value) = self.connection.query_row(
"SELECT account, value FROM received_notes WHERE nf = ?1",
params![nf.0.to_vec()],
|row| row.get(0),
|row| {
let account: u32 = row.get(0)?;
let value: i64 = row.get(1)?;
Ok((account, value))
},
)?;
Ok(value)
Ok((account, value))
}
pub fn get_balance(&self) -> anyhow::Result<u64> {
pub fn get_balance(&self, account: u32) -> anyhow::Result<u64> {
let balance: Option<i64> = self.connection.query_row(
"SELECT SUM(value) FROM received_notes WHERE spent IS NULL OR spent = 0",
NO_PARAMS,
"SELECT SUM(value) FROM received_notes WHERE (spent IS NULL OR spent = 0) AND account = ?1",
params![account],
|row| row.get(0),
)?;
Ok(balance.unwrap_or(0) as u64)
}
pub fn get_spendable_balance(&self, anchor_height: u32) -> anyhow::Result<u64> {
pub fn get_spendable_balance(&self, account: u32, anchor_height: u32) -> anyhow::Result<u64> {
let balance: Option<i64> = self.connection.query_row(
"SELECT SUM(value) FROM received_notes WHERE spent IS NULL AND height <= ?1",
params![anchor_height],
"SELECT SUM(value) FROM received_notes WHERE spent IS NULL AND height <= ?1 AND account = ?2",
params![anchor_height, account],
|row| row.get(0),
)?;
Ok(balance.unwrap_or(0) as u64)
@ -293,7 +310,7 @@ impl DbAdapter {
Some((height, tree)) => {
let tree = CTree::read(&*tree)?;
let mut statement = self.connection.prepare(
"SELECT id_note, witness FROM sapling_witnesses w, received_notes n WHERE w.height = ?1 AND w.note = n.id_note AND (n.spent IS NULL OR n.spent = 0)")?;
"SELECT id_note, witness FROM sapling_witnesses w, received_notes n WHERE w.height = ?1 AND w.note = n.id_note AND (n.spent IS NULL OR n.spent = 0)")?;
let ws = statement.query_map(params![height], |row| {
let id_note: u32 = row.get(0)?;
let witness: Vec<u8> = row.get(1)?;
@ -309,18 +326,23 @@ impl DbAdapter {
})
}
pub fn get_nullifiers(&self) -> anyhow::Result<HashMap<Nf, u32>> {
pub fn get_nullifiers(&self) -> anyhow::Result<HashMap<Nf, NfRef>> {
let mut statement = self
.connection
.prepare("SELECT id_note, nf FROM received_notes WHERE spent IS NULL OR spent = 0")?;
.prepare("SELECT id_note, account, nf FROM received_notes WHERE spent IS NULL OR spent = 0")?;
let nfs_res = statement.query_map(NO_PARAMS, |row| {
let id_note: u32 = row.get(0)?;
let nf_vec: Vec<u8> = row.get(1)?;
let account: u32 = row.get(1)?;
let nf_vec: Vec<u8> = row.get(2)?;
let mut nf = [0u8; 32];
nf.clone_from_slice(&nf_vec);
Ok((id_note, nf))
let nf_ref = NfRef {
id_note,
account
};
Ok((nf_ref, nf))
})?;
let mut nfs: HashMap<Nf, u32> = HashMap::new();
let mut nfs: HashMap<Nf, NfRef> = HashMap::new();
for n in nfs_res {
let n = n?;
nfs.insert(Nf(n.1), n.0);
@ -329,11 +351,11 @@ impl DbAdapter {
Ok(nfs)
}
pub fn get_nullifier_amounts(&self) -> anyhow::Result<HashMap<Vec<u8>, u64>> {
pub fn get_nullifier_amounts(&self, account: u32) -> anyhow::Result<HashMap<Vec<u8>, u64>> {
let mut statement = self
.connection
.prepare("SELECT value, nf FROM received_notes WHERE spent IS NULL OR spent = 0")?;
let nfs_res = statement.query_map(NO_PARAMS, |row| {
.prepare("SELECT value, nf FROM received_notes WHERE account = ?1 AND (spent IS NULL OR spent = 0)")?;
let nfs_res = statement.query_map(params![account], |row| {
let amount: i64 = row.get(0)?;
let nf: Vec<u8> = row.get(1)?;
Ok((amount, nf))
@ -349,15 +371,16 @@ impl DbAdapter {
pub fn get_spendable_notes(
&self,
account: u32,
anchor_height: u32,
fvk: &ExtendedFullViewingKey,
) -> anyhow::Result<Vec<SpendableNote>> {
let mut statement = self.connection.prepare(
"SELECT id_note, diversifier, value, rcm, witness FROM received_notes r, sapling_witnesses w WHERE spent IS NULL
"SELECT id_note, diversifier, value, rcm, witness FROM received_notes r, sapling_witnesses w WHERE spent IS NULL AND account = ?2
AND w.height = (
SELECT MAX(height) FROM sapling_witnesses WHERE height <= ?1
) AND r.id_note = w.note")?;
let notes = statement.query_map(params![anchor_height], |row| {
let notes = statement.query_map(params![anchor_height, account], |row| {
let id_note: u32 = row.get(0)?;
let diversifier: Vec<u8> = row.get(1)?;
@ -401,32 +424,34 @@ impl DbAdapter {
Ok(())
}
pub fn get_seed(&self, account: u32) -> anyhow::Result<String> {
log::debug!("+get_seed");
let ivk = self.connection.query_row(
"SELECT seed FROM accounts WHERE id_account = ?1",
pub fn get_backup(&self, account: u32) -> anyhow::Result<(Option<String>, Option<String>, String)> {
log::debug!("+get_backup");
let (seed, sk, ivk) = self.connection.query_row(
"SELECT seed, sk, ivk FROM accounts WHERE id_account = ?1",
params![account],
|row| {
let seed: Option<String> = row.get(0)?;
let sk: Option<String> = row.get(0)?;
let ivk: String = row.get(0)?;
Ok(ivk)
Ok((seed, sk, ivk))
},
)?;
log::debug!("-get_seed");
Ok(ivk)
log::debug!("-get_backup");
Ok((seed, sk, ivk))
}
pub fn get_sk(&self, account: u32) -> anyhow::Result<String> {
log::debug!("+get_sk");
let ivk = self.connection.query_row(
log::info!("+get_sk");
let sk = self.connection.query_row(
"SELECT sk FROM accounts WHERE id_account = ?1",
params![account],
|row| {
let ivk: String = row.get(0)?;
Ok(ivk)
let sk: String = row.get(0)?;
Ok(sk)
},
)?;
log::debug!("-get_sk");
Ok(ivk)
log::info!("-get_sk");
Ok(sk)
}
pub fn get_ivk(&self, account: u32) -> anyhow::Result<String> {
@ -456,9 +481,10 @@ mod tests {
db.trim_to_height(0).unwrap();
db.store_block(1, &[0u8; 32], 0, &CTree::new()).unwrap();
let id_tx = db.store_transaction(&[0; 32], 1, 0, 20).unwrap();
let id_tx = db.store_transaction(&[0; 32], 1, 1, 0, 20).unwrap();
db.store_received_note(
&ReceivedNote {
account: 1,
height: 1,
output_index: 0,
diversifier: vec![],
@ -470,7 +496,7 @@ mod tests {
id_tx,
5,
)
.unwrap();
.unwrap();
let witness = Witness {
position: 10,
id_note: 0,
@ -485,7 +511,7 @@ mod tests {
#[test]
fn test_balance() {
let db = DbAdapter::new(DEFAULT_DB_PATH).unwrap();
let balance = db.get_balance().unwrap();
let balance = db.get_balance(1).unwrap();
println!("{}", balance);
}
}

View File

@ -1,5 +1,4 @@
use crate::NETWORK;
use anyhow::anyhow;
use bip39::{Language, Mnemonic, Seed};
use zcash_client_backend::encoding::{
decode_extended_full_viewing_key, decode_extended_spending_key,
@ -8,8 +7,40 @@ use zcash_client_backend::encoding::{
use zcash_primitives::consensus::Parameters;
use zcash_primitives::zip32::{ChildIndex, ExtendedFullViewingKey, ExtendedSpendingKey};
pub fn get_secret_key(seed: &str) -> anyhow::Result<String> {
let mnemonic = Mnemonic::from_phrase(&seed, Language::English)?;
pub fn decode_key(key: &str) -> anyhow::Result<(Option<String>, Option<String>, String, String)> {
let res =
if let Ok(mnemonic) = Mnemonic::from_phrase(&key, Language::English) {
let (sk, ivk, pa) = derive_secret_key(&mnemonic)?;
Ok((Some(key.to_string()), Some(sk), ivk, pa))
}
else if let Ok(Some(sk)) = decode_extended_spending_key(NETWORK.hrp_sapling_extended_spending_key(), &key) {
let (ivk, pa) = derive_viewing_key(&sk)?;
Ok((None, Some(key.to_string()), ivk, pa))
}
else if let Ok(Some(fvk)) = decode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &key) {
let pa = derive_address(&fvk)?;
Ok((None, None, key.to_string(), pa))
}
else {
Err(anyhow::anyhow!("Not a valid key"))
};
res
}
pub fn is_valid_key(key: &str) -> bool {
if Mnemonic::from_phrase(&key, Language::English).is_ok() {
return true;
}
if decode_extended_spending_key(NETWORK.hrp_sapling_extended_spending_key(), &key).is_ok() {
return true;
}
if decode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &key).is_ok() {
return true;
}
false
}
pub fn derive_secret_key(mnemonic: &Mnemonic) -> anyhow::Result<(String, String, String)> {
let seed = Seed::new(&mnemonic, "");
let master = ExtendedSpendingKey::master(seed.as_bytes());
let path = [
@ -18,28 +49,21 @@ pub fn get_secret_key(seed: &str) -> anyhow::Result<String> {
ChildIndex::Hardened(0),
];
let extsk = ExtendedSpendingKey::from_path(&master, &path);
let spending_key =
let sk =
encode_extended_spending_key(NETWORK.hrp_sapling_extended_spending_key(), &extsk);
Ok(spending_key)
let (fvk, pa) = derive_viewing_key(&extsk)?;
Ok((sk, fvk, pa))
}
pub fn get_viewing_key(secret_key: &str) -> anyhow::Result<String> {
let extsk =
decode_extended_spending_key(NETWORK.hrp_sapling_extended_spending_key(), secret_key)?
.ok_or(anyhow!("Invalid Secret Key"))?;
let fvk = ExtendedFullViewingKey::from(&extsk);
let viewing_key =
encode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &fvk);
Ok(viewing_key)
pub fn derive_viewing_key(extsk: &ExtendedSpendingKey) -> anyhow::Result<(String, String)> {
let fvk = ExtendedFullViewingKey::from(extsk);
let pa = derive_address(&fvk)?;
let fvk = encode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &fvk);
Ok((fvk, pa))
}
pub fn get_address(viewing_key: &str) -> anyhow::Result<String> {
let fvk = decode_extended_full_viewing_key(
NETWORK.hrp_sapling_extended_full_viewing_key(),
&viewing_key,
)?
.ok_or(anyhow!("Invalid Viewing Key"))?;
pub fn derive_address(fvk: &ExtendedFullViewingKey) -> anyhow::Result<String> {
let (_, payment_address) = fvk.default_address().unwrap();
let address = encode_payment_address(NETWORK.hrp_sapling_payment_address(), &payment_address);
Ok(address)

View File

@ -3,7 +3,16 @@ use zcash_primitives::consensus::Network;
#[path = "generated/cash.z.wallet.sdk.rpc.rs"]
pub mod lw_rpc;
pub const NETWORK: Network = Network::MainNetwork;
pub const NETWORK: Network = Network::TestNetwork;
// Mainnet
// pub const LWD_URL: &str = "https://mainnet.lightwalletd.com:9067";
// pub const LWD_URL: &str = "https://lwdv3.zecwallet.co";
// Testnet
pub const LWD_URL: &str = "https://testnet.lightwalletd.com:9067";
// pub const LWD_URL: &str = "http://lwd.hanh.me:9067";
// pub const LWD_URL: &str = "http://127.0.0.1:9067";
mod builder;
mod chain;
@ -18,14 +27,14 @@ mod wallet;
pub use crate::builder::advance_tree;
pub use crate::chain::{
calculate_tree_state_v2, connect_lightwalletd, download_chain, get_latest_height, sync,
DecryptNode, LWD_URL, ChainError
DecryptNode, ChainError
};
pub use crate::commitment::{CTree, Witness};
pub use crate::db::DbAdapter;
pub use crate::key::{get_address, get_secret_key, get_viewing_key};
pub use crate::lw_rpc::compact_tx_streamer_client::CompactTxStreamerClient;
pub use crate::lw_rpc::*;
pub use crate::mempool::MemPool;
pub use crate::scan::{latest_height, scan_all, sync_async};
pub use crate::wallet::{Wallet, WalletBalance, DEFAULT_ACCOUNT};
pub use crate::wallet::{Wallet, WalletBalance};
pub use crate::print::*;
pub use crate::key::is_valid_key;

View File

@ -1,7 +1,7 @@
use bip39::{Language, Mnemonic};
use rand::rngs::OsRng;
use rand::RngCore;
use sync::{DbAdapter, Wallet, DEFAULT_ACCOUNT, ChainError, Witness, print_witness2};
use sync::{DbAdapter, Wallet, ChainError, Witness, print_witness2};
use rusqlite::NO_PARAMS;
const DB_NAME: &str = "zec.db";
@ -23,18 +23,21 @@ async fn test() -> anyhow::Result<()> {
log::info!("Height = {}", height);
};
let wallet = Wallet::new(DB_NAME);
wallet.new_account_with_seed(&seed).unwrap();
let res = wallet.sync(DEFAULT_ACCOUNT, progress).await;
wallet.new_account_with_key("test", &seed).unwrap();
let res = wallet.sync(progress).await;
if let Err(err) = res {
if let Some(_) = err.downcast_ref::<ChainError>() {
println!("REORG");
}
else {
panic!(err);
}
}
// let tx_id = wallet
// .send_payment(DEFAULT_ACCOUNT, &address, 50000)
// .await
// .unwrap();
// println!("TXID = {}", tx_id);
let tx_id = wallet
.send_payment(1, &address, 50000)
.await
.unwrap();
println!("TXID = {}", tx_id);
Ok(())
}
@ -57,7 +60,7 @@ fn test_rewind() {
#[allow(dead_code)]
fn test_get_balance() {
let db = DbAdapter::new(DB_NAME).unwrap();
let balance = db.get_balance().unwrap();
let balance = db.get_balance(1).unwrap();
println!("Balance = {}", (balance as f64) / 100_000_000.0);
}
@ -77,6 +80,7 @@ fn test_invalid_witness() {
print_witness2(&w);
}
#[allow(dead_code)]
fn w() {
let db = DbAdapter::new("zec.db").unwrap();
// let w_b: Vec<u8> = db.connection.query_row("SELECT witness FROM sapling_witnesses WHERE note = 66 AND height = 1466097", NO_PARAMS, |row| row.get(0)).unwrap();

View File

@ -21,6 +21,7 @@ struct MemPoolTransacton {
pub struct MemPool {
db_path: String,
account: u32,
ivk: Option<SaplingIvk>,
height: BlockHeight,
transactions: HashMap<Vec<u8>, MemPoolTransacton>,
@ -32,6 +33,7 @@ impl MemPool {
pub fn new(db_path: &str) -> MemPool {
MemPool {
db_path: db_path.to_string(),
account: 0,
ivk: None,
height: BlockHeight::from(0),
transactions: HashMap::new(),
@ -43,11 +45,13 @@ impl MemPool {
pub fn set_account(&mut self, account: u32) -> anyhow::Result<()> {
let db = DbAdapter::new(&self.db_path)?;
let ivk = db.get_ivk(account)?;
self.account = account;
self.set_ivk(&ivk);
self.clear()?;
Ok(())
}
pub fn set_ivk(&mut self, ivk: &str) {
fn set_ivk(&mut self, ivk: &str) {
let fvk =
decode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &ivk)
.unwrap()
@ -56,30 +60,30 @@ impl MemPool {
self.ivk = Some(ivk);
}
pub async fn scan(&mut self) -> anyhow::Result<()> {
pub async fn scan(&mut self) -> anyhow::Result<i64> {
if self.ivk.is_some() {
let ivk = self.ivk.as_ref().unwrap().clone();
let mut client = connect_lightwalletd().await?;
let height = BlockHeight::from(get_latest_height(&mut client).await?);
if self.height != height {
// New blocks invalidate the mempool
let db = DbAdapter::new(&self.db_path)?;
self.clear(&db)?;
self.clear()?;
}
self.height = height;
self.update(&mut client, &ivk).await?;
}
Ok(())
Ok(self.balance)
}
pub fn get_unconfirmed_balance(&self) -> i64 {
self.balance
}
fn clear(&mut self, db: &DbAdapter) -> anyhow::Result<()> {
fn clear(&mut self) -> anyhow::Result<()> {
let db = DbAdapter::new(&self.db_path)?;
self.height = BlockHeight::from_u32(0);
self.nfs = db.get_nullifier_amounts()?;
self.nfs = db.get_nullifier_amounts(self.account)?;
self.transactions.clear();
self.balance = 0;
Ok(())

View File

@ -1,5 +1,5 @@
use crate::builder::BlockProcessor;
use crate::chain::Nf;
use crate::chain::{Nf, NfRef};
use crate::db::{DbAdapter, ReceivedNote};
use crate::lw_rpc::compact_tx_streamer_client::CompactTxStreamerClient;
use crate::{
@ -18,9 +18,12 @@ use zcash_primitives::consensus::{NetworkUpgrade, Parameters};
use zcash_primitives::sapling::Node;
use zcash_primitives::zip32::ExtendedFullViewingKey;
use std::panic;
use std::collections::HashMap;
pub async fn scan_all(fvks: &[ExtendedFullViewingKey]) -> anyhow::Result<()> {
let decrypter = DecryptNode::new(fvks.to_vec());
let fvks: HashMap<_, _> = fvks.iter().enumerate().map(|(i, fvk)|
(i as u32, fvk.clone())).collect();
let decrypter = DecryptNode::new(fvks);
let total_start = Instant::now();
let mut client = CompactTxStreamerClient::connect(LWD_URL).await?;
@ -68,27 +71,33 @@ impl std::fmt::Debug for Blocks {
pub type ProgressCallback = Arc<Mutex<dyn Fn(u32) + Send>>;
pub async fn sync_async(
ivk: &str,
chunk_size: u32,
db_path: &str,
target_height_offset: u32,
progress_callback: ProgressCallback,
) -> anyhow::Result<()> {
let db_path = db_path.to_string();
let fvk =
decode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &ivk)?
.ok_or_else(|| anyhow::anyhow!("Invalid key"))?;
let decrypter = DecryptNode::new(vec![fvk]);
let mut client = connect_lightwalletd().await?;
let (start_height, mut prev_hash) = {
let (start_height, mut prev_hash, fvks) = {
let db = DbAdapter::new(&db_path)?;
let height = db.get_db_height()?;
(height, db.get_db_hash(height)?)
let hash = db.get_db_hash(height)?;
let fvks = db.get_fvks()?;
(height, hash, fvks)
};
let end_height = get_latest_height(&mut client).await?;
let end_height = (end_height - target_height_offset).max(start_height);
let fvks: HashMap<_, _> = fvks.iter().map(|(&account, fvk)| {
let fvk =
decode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &fvk)
.unwrap()
.unwrap();
(account, fvk)
}).collect();
let decrypter = DecryptNode::new(fvks);
let (downloader_tx, mut download_rx) = mpsc::channel::<Range<u32>>(2);
let (processor_tx, mut processor_rx) = mpsc::channel::<Blocks>(1);
@ -133,9 +142,9 @@ pub async fn sync_async(
for b in dec_blocks.iter() {
let mut my_nfs: Vec<Nf> = vec![];
for nf in b.spends.iter() {
if let Some(&id) = nfs.get(nf) {
log::info!("NF FOUND {} {}", id, b.height);
db.mark_spent(id, b.height)?;
if let Some(&nf_ref) = nfs.get(nf) {
log::info!("NF FOUND {} {}", nf_ref.id_note, b.height);
db.mark_spent(nf_ref.id_note, b.height)?;
my_nfs.push(*nf);
}
}
@ -152,12 +161,14 @@ pub async fn sync_async(
let id_tx = db.store_transaction(
&n.txid,
n.account,
n.height,
b.compact_block.time,
n.tx_index as u32,
)?;
let id_note = db.store_received_note(
&ReceivedNote {
account: n.account,
height: n.height,
output_index: n.output_index as u32,
diversifier: n.pa.diversifier().0.to_vec(),
@ -170,7 +181,7 @@ pub async fn sync_async(
n.position_in_block,
)?;
db.add_value(id_tx, note.value as i64)?;
nfs.insert(Nf(nf.0), id_note);
nfs.insert(Nf(nf.0), NfRef { id_note, account: n.account });
let w = Witness::new(p as usize, id_note, Some(n.clone()));
witnesses.push(w);
@ -183,10 +194,11 @@ pub async fn sync_async(
nf.copy_from_slice(&cs.nf);
let nf = Nf(nf);
if my_nfs.contains(&nf) {
let note_value = db.get_received_note_value(&nf)?;
let (account, note_value) = db.get_received_note_value(&nf)?;
let txid = &*tx.hash;
let id_tx = db.store_transaction(
txid,
account,
b.height,
b.compact_block.time,
tx_index as u32,

View File

@ -1,7 +1,6 @@
use crate::chain::send_transaction;
use crate::mempool::MemPool;
use crate::scan::ProgressCallback;
use crate::{connect_lightwalletd, get_address, get_latest_height, get_secret_key, get_viewing_key, DbAdapter, NETWORK, BlockId, CTree};
use crate::{connect_lightwalletd, get_latest_height, DbAdapter, NETWORK, BlockId, CTree};
use anyhow::Context;
use bip39::{Language, Mnemonic};
use rand::prelude::SliceRandom;
@ -11,7 +10,7 @@ use std::sync::Arc;
use tokio::sync::Mutex;
use zcash_client_backend::address::RecipientAddress;
use zcash_client_backend::data_api::wallet::ANCHOR_OFFSET;
use zcash_client_backend::encoding::{decode_extended_spending_key, decode_payment_address};
use zcash_client_backend::encoding::decode_extended_spending_key;
use zcash_params::{OUTPUT_PARAMS, SPEND_PARAMS};
use zcash_primitives::consensus::{BlockHeight, BranchId, Parameters};
use zcash_primitives::transaction::builder::Builder;
@ -20,8 +19,8 @@ use zcash_primitives::transaction::components::Amount;
use zcash_primitives::zip32::ExtendedFullViewingKey;
use zcash_proofs::prover::LocalTxProver;
use tonic::Request;
use crate::key::{is_valid_key, decode_key};
pub const DEFAULT_ACCOUNT: u32 = 1;
const DEFAULT_CHUNK_SIZE: u32 = 100_000;
pub struct Wallet {
@ -59,48 +58,48 @@ impl Wallet {
}
}
pub fn valid_seed(seed: &str) -> bool {
get_secret_key(&seed).is_ok()
pub fn valid_key(key: &str) -> bool {
is_valid_key(key)
}
pub fn valid_address(address: &str) -> bool {
decode_payment_address(NETWORK.hrp_sapling_payment_address(), address).is_ok()
let recipient = RecipientAddress::decode(&NETWORK, address);
recipient.is_some()
}
pub fn new_seed(&self) -> anyhow::Result<()> {
let mut entropy = [0u8; 32];
OsRng.fill_bytes(&mut entropy);
let mnemonic = Mnemonic::from_entropy(&entropy, Language::English)?;
let seed = mnemonic.phrase();
self.new_account_with_seed(seed)?;
Ok(())
pub fn new_account(&self, name: &str, data: &str) -> anyhow::Result<u32> {
if data.is_empty() {
let mut entropy = [0u8; 32];
OsRng.fill_bytes(&mut entropy);
let mnemonic = Mnemonic::from_entropy(&entropy, Language::English)?;
let seed = mnemonic.phrase();
self.new_account_with_key(name, seed)
}
else {
self.new_account_with_key(name, data)
}
}
pub fn get_seed(&self, account: u32) -> anyhow::Result<String> {
self.db.get_seed(account)
pub fn get_backup(&self, account: u32) -> anyhow::Result<String> {
let (seed, sk, ivk) = self.db.get_backup(account)?;
if let Some(seed) = seed { return Ok(seed); }
if let Some(sk) = sk { return Ok(sk); }
Ok(ivk)
}
pub fn has_account(&self, account: u32) -> anyhow::Result<bool> {
self.db.has_account(account)
}
pub fn new_account_with_seed(&self, seed: &str) -> anyhow::Result<()> {
let sk = get_secret_key(&seed).unwrap();
let vk = get_viewing_key(&sk).unwrap();
let pa = get_address(&vk).unwrap();
self.db.store_account(seed, &sk, &vk, &pa)?;
Ok(())
pub fn new_account_with_key(&self, name: &str, key: &str) -> anyhow::Result<u32> {
let (seed, sk, ivk, pa) = decode_key(key)?;
let account = self.db.store_account(name, seed.as_deref(), sk.as_deref(), &ivk, &pa)?;
Ok(account)
}
async fn scan_async(
ivk: &str,
db_path: &str,
chunk_size: u32,
target_height_offset: u32,
progress_callback: ProgressCallback,
) -> anyhow::Result<()> {
crate::scan::sync_async(
ivk,
chunk_size,
db_path,
target_height_offset,
@ -118,22 +117,19 @@ impl Wallet {
// Not a method in order to avoid locking the instance
pub async fn sync_ex(
db_path: &str,
ivk: &str,
progress_callback: impl Fn(u32) + Send + 'static,
) -> anyhow::Result<()> {
let cb = Arc::new(Mutex::new(progress_callback));
Self::scan_async(&ivk, db_path, DEFAULT_CHUNK_SIZE, 10, cb.clone()).await?;
Self::scan_async(&ivk, db_path, DEFAULT_CHUNK_SIZE, 0, cb.clone()).await?;
Self::scan_async(db_path, DEFAULT_CHUNK_SIZE, 10, cb.clone()).await?;
Self::scan_async(db_path, DEFAULT_CHUNK_SIZE, 0, cb.clone()).await?;
Ok(())
}
pub async fn sync(
&self,
account: u32,
progress_callback: impl Fn(u32) + Send + 'static,
) -> anyhow::Result<()> {
let ivk = self.get_ivk(account)?;
Self::sync_ex(&self.db_path, &ivk, progress_callback).await
Self::sync_ex(&self.db_path, progress_callback).await
}
pub async fn skip_to_last_height(&self) -> anyhow::Result<()> {
@ -155,20 +151,6 @@ impl Wallet {
self.db.trim_to_height(height)
}
pub async fn get_balance(&self, mempool: &MemPool) -> anyhow::Result<WalletBalance> {
let last_height = Self::get_latest_height().await?;
let anchor_height = last_height - ANCHOR_OFFSET;
let confirmed = self.db.get_balance()?;
let unconfirmed = mempool.get_unconfirmed_balance();
let spendable = self.db.get_spendable_balance(anchor_height)?;
Ok(WalletBalance {
confirmed,
unconfirmed,
spendable,
})
}
pub async fn send_payment(
&self,
account: u32,
@ -193,7 +175,7 @@ impl Wallet {
.ok_or_else(|| anyhow::anyhow!("No spendable notes"))?;
let anchor_height = anchor_height.min(last_height - ANCHOR_OFFSET);
log::info!("Anchor = {}", anchor_height);
let mut notes = self.db.get_spendable_notes(anchor_height, &extfvk)?;
let mut notes = self.db.get_spendable_notes(account, anchor_height, &extfvk)?;
notes.shuffle(&mut OsRng);
log::info!("Spendable notes = {}", notes.len());
@ -260,7 +242,8 @@ impl Wallet {
#[cfg(test)]
mod tests {
use crate::wallet::Wallet;
use crate::{get_address, get_secret_key, get_viewing_key};
use crate::key::derive_secret_key;
use bip39::{Mnemonic, Language};
#[tokio::test]
async fn test_wallet_seed() {
@ -269,7 +252,7 @@ mod tests {
let seed = dotenv::var("SEED").unwrap();
let wallet = Wallet::new("zec.db");
wallet.new_account_with_seed(&seed).unwrap();
wallet.new_account_with_key("test", &seed).unwrap();
}
#[tokio::test]
@ -278,14 +261,11 @@ mod tests {
env_logger::init();
let seed = dotenv::var("SEED").unwrap();
let sk = get_secret_key(&seed).unwrap();
let vk = get_viewing_key(&sk).unwrap();
println!("{}", vk);
let pa = get_address(&vk).unwrap();
println!("{}", pa);
let wallet = Wallet::new("zec.db");
let tx_id = wallet.send_payment(1, &pa, 1000).await.unwrap();
println!("TXID = {}", tx_id);
let (sk, vk, pa) = derive_secret_key(&Mnemonic::from_phrase(&seed, Language::English).unwrap()).unwrap();
println!("{} {} {}", sk, vk, pa);
// let wallet = Wallet::new("zec.db");
//
// let tx_id = wallet.send_payment(1, &pa, 1000).await.unwrap();
// println!("TXID = {}", tx_id);
}
}