multi account

This commit is contained in:
Hanh 2021-06-29 15:04:12 +08:00
parent f1d948c0f6
commit cb44cb2438
8 changed files with 263 additions and 193 deletions

View File

@ -1,7 +1,7 @@
use crate::commitment::{CTree, Witness}; use crate::commitment::{CTree, Witness};
use crate::lw_rpc::compact_tx_streamer_client::CompactTxStreamerClient; use crate::lw_rpc::compact_tx_streamer_client::CompactTxStreamerClient;
use crate::lw_rpc::*; use crate::lw_rpc::*;
use crate::{advance_tree, NETWORK}; use crate::{advance_tree, NETWORK, LWD_URL};
use ff::PrimeField; use ff::PrimeField;
use group::GroupEncoding; use group::GroupEncoding;
use log::info; use log::info;
@ -17,13 +17,9 @@ use zcash_primitives::sapling::note_encryption::try_sapling_compact_note_decrypt
use zcash_primitives::sapling::{Node, Note, PaymentAddress}; use zcash_primitives::sapling::{Node, Note, PaymentAddress};
use zcash_primitives::transaction::components::sapling::CompactOutputDescription; use zcash_primitives::transaction::components::sapling::CompactOutputDescription;
use zcash_primitives::zip32::ExtendedFullViewingKey; use zcash_primitives::zip32::ExtendedFullViewingKey;
use std::collections::HashMap;
const MAX_CHUNK: u32 = 50000; const MAX_CHUNK: u32 = 50000;
pub const LWD_URL: &str = "https://mainnet.lightwalletd.com:9067";
// pub const LWD_URL: &str = "https://testnet.lightwalletd.com:9067";
// pub const LWD_URL: &str = "http://lwd.hanh.me:9067";
// pub const LWD_URL: &str = "https://lwdv3.zecwallet.co";
// pub const LWD_URL: &str = "http://127.0.0.1:9067";
pub async fn get_latest_height( pub async fn get_latest_height(
client: &mut CompactTxStreamerClient<Channel>, client: &mut CompactTxStreamerClient<Channel>,
@ -82,12 +78,18 @@ pub async fn download_chain(
} }
pub struct DecryptNode { pub struct DecryptNode {
fvks: Vec<ExtendedFullViewingKey>, fvks: HashMap<u32, ExtendedFullViewingKey>,
} }
#[derive(Eq, Hash, PartialEq, Copy, Clone)] #[derive(Eq, Hash, PartialEq, Copy, Clone)]
pub struct Nf(pub [u8; 32]); pub struct Nf(pub [u8; 32]);
#[derive(Copy, Clone)]
pub struct NfRef {
pub id_note: u32,
pub account: u32
}
pub struct DecryptedBlock<'a> { pub struct DecryptedBlock<'a> {
pub height: u32, pub height: u32,
pub notes: Vec<DecryptedNote>, pub notes: Vec<DecryptedNote>,
@ -98,6 +100,7 @@ pub struct DecryptedBlock<'a> {
#[derive(Clone)] #[derive(Clone)]
pub struct DecryptedNote { pub struct DecryptedNote {
pub account: u32,
pub ivk: ExtendedFullViewingKey, pub ivk: ExtendedFullViewingKey,
pub note: Note, pub note: Note,
pub pa: PaymentAddress, pub pa: PaymentAddress,
@ -126,7 +129,7 @@ pub fn to_output_description(co: &CompactOutput) -> CompactOutputDescription {
fn decrypt_notes<'a>( fn decrypt_notes<'a>(
block: &'a CompactBlock, block: &'a CompactBlock,
fvks: &[ExtendedFullViewingKey], fvks: &HashMap<u32, ExtendedFullViewingKey>,
) -> DecryptedBlock<'a> { ) -> DecryptedBlock<'a> {
let height = BlockHeight::from_u32(block.height as u32); let height = BlockHeight::from_u32(block.height as u32);
let mut count_outputs = 0u32; let mut count_outputs = 0u32;
@ -140,13 +143,14 @@ fn decrypt_notes<'a>(
} }
for (output_index, co) in vtx.outputs.iter().enumerate() { for (output_index, co) in vtx.outputs.iter().enumerate() {
for fvk in fvks.iter() { for (&account, fvk) in fvks.iter() {
let ivk = &fvk.fvk.vk.ivk(); let ivk = &fvk.fvk.vk.ivk();
let od = to_output_description(co); let od = to_output_description(co);
if let Some((note, pa)) = if let Some((note, pa)) =
try_sapling_compact_note_decryption(&NETWORK, height, ivk, &od) try_sapling_compact_note_decryption(&NETWORK, height, ivk, &od)
{ {
notes.push(DecryptedNote { notes.push(DecryptedNote {
account,
ivk: fvk.clone(), ivk: fvk.clone(),
note, note,
pa, pa,
@ -171,7 +175,7 @@ fn decrypt_notes<'a>(
} }
impl DecryptNode { impl DecryptNode {
pub fn new(fvks: Vec<ExtendedFullViewingKey>) -> DecryptNode { pub fn new(fvks: HashMap<u32, ExtendedFullViewingKey>) -> DecryptNode {
DecryptNode { fvks } DecryptNode { fvks }
} }
@ -324,12 +328,15 @@ pub async fn connect_lightwalletd() -> anyhow::Result<CompactTxStreamerClient<Ch
Ok(client) Ok(client)
} }
pub async fn sync(ivk: &str) -> anyhow::Result<()> { pub async fn sync(fvks: &HashMap<u32, String>) -> anyhow::Result<()> {
let fvks: HashMap<_, _> = fvks.iter().map(|(&account, fvk)| {
let fvk = let fvk =
decode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &ivk) decode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &fvk)
.unwrap() .unwrap()
.unwrap(); .unwrap();
let decrypter = DecryptNode::new(vec![fvk]); (account, fvk)
}).collect();
let decrypter = DecryptNode::new(fvks);
let mut client = connect_lightwalletd().await?; let mut client = connect_lightwalletd().await?;
let start_height: u32 = crate::NETWORK let start_height: u32 = crate::NETWORK
.activation_height(NetworkUpgrade::Sapling) .activation_height(NetworkUpgrade::Sapling)
@ -361,7 +368,7 @@ pub async fn sync(ivk: &str) -> anyhow::Result<()> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::chain::LWD_URL; use crate::LWD_URL;
#[allow(unused_imports)] #[allow(unused_imports)]
use crate::chain::{ use crate::chain::{
calculate_tree_state_v1, calculate_tree_state_v2, download_chain, get_latest_height, calculate_tree_state_v1, calculate_tree_state_v2, download_chain, get_latest_height,
@ -373,6 +380,8 @@ mod tests {
use std::time::Instant; use std::time::Instant;
use zcash_client_backend::encoding::decode_extended_full_viewing_key; use zcash_client_backend::encoding::decode_extended_full_viewing_key;
use zcash_primitives::consensus::{NetworkUpgrade, Parameters}; use zcash_primitives::consensus::{NetworkUpgrade, Parameters};
use zcash_primitives::zip32::ExtendedFullViewingKey;
use std::collections::HashMap;
#[tokio::test] #[tokio::test]
async fn test_get_latest_height() -> anyhow::Result<()> { async fn test_get_latest_height() -> anyhow::Result<()> {
@ -387,11 +396,13 @@ mod tests {
dotenv::dotenv().unwrap(); dotenv::dotenv().unwrap();
let fvk = dotenv::var("FVK").unwrap(); let fvk = dotenv::var("FVK").unwrap();
let mut fvks: HashMap<u32, ExtendedFullViewingKey> = HashMap::new();
let fvk = let fvk =
decode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &fvk) decode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &fvk)
.unwrap() .unwrap()
.unwrap(); .unwrap();
let decrypter = DecryptNode::new(vec![fvk]); fvks.insert(1, fvk);
let decrypter = DecryptNode::new(fvks);
let mut client = CompactTxStreamerClient::connect(LWD_URL).await?; let mut client = CompactTxStreamerClient::connect(LWD_URL).await?;
let start_height: u32 = crate::NETWORK let start_height: u32 = crate::NETWORK
.activation_height(NetworkUpgrade::Sapling) .activation_height(NetworkUpgrade::Sapling)

146
src/db.rs
View File

@ -1,4 +1,4 @@
use crate::chain::Nf; use crate::chain::{Nf, NfRef};
use crate::{CTree, Witness}; use crate::{CTree, Witness};
use rusqlite::{params, Connection, OptionalExtension, NO_PARAMS}; use rusqlite::{params, Connection, OptionalExtension, NO_PARAMS};
use std::collections::HashMap; use std::collections::HashMap;
@ -15,6 +15,7 @@ pub struct DbAdapter {
} }
pub struct ReceivedNote { pub struct ReceivedNote {
pub account: u32,
pub height: u32, pub height: u32,
pub output_index: u32, pub output_index: u32,
pub diversifier: Vec<u8>, pub diversifier: Vec<u8>,
@ -41,9 +42,10 @@ impl DbAdapter {
self.connection.execute( self.connection.execute(
"CREATE TABLE IF NOT EXISTS accounts ( "CREATE TABLE IF NOT EXISTS accounts (
id_account INTEGER PRIMARY KEY, id_account INTEGER PRIMARY KEY,
seed TEXT NOT NULL, name TEXT NOT NULL,
sk TEXT NOT NULL UNIQUE, seed TEXT,
ivk TEXT NOT NULL, sk TEXT,
ivk TEXT NOT NULL UNIQUE,
address TEXT NOT NULL)", address TEXT NOT NULL)",
NO_PARAMS, NO_PARAMS,
)?; )?;
@ -60,6 +62,7 @@ impl DbAdapter {
self.connection.execute( self.connection.execute(
"CREATE TABLE IF NOT EXISTS transactions ( "CREATE TABLE IF NOT EXISTS transactions (
id_tx INTEGER PRIMARY KEY, id_tx INTEGER PRIMARY KEY,
account INTEGER NOT NULL,
txid BLOB NOT NULL UNIQUE, txid BLOB NOT NULL UNIQUE,
height INTEGER NOT NULL, height INTEGER NOT NULL,
timestamp INTEGER NOT NULL, timestamp INTEGER NOT NULL,
@ -71,6 +74,7 @@ impl DbAdapter {
self.connection.execute( self.connection.execute(
"CREATE TABLE IF NOT EXISTS received_notes ( "CREATE TABLE IF NOT EXISTS received_notes (
id_note INTEGER PRIMARY KEY, id_note INTEGER PRIMARY KEY,
account INTEGER NOT NULL,
position INTEGER NOT NULL, position INTEGER NOT NULL,
tx INTEGER NOT NULL, tx INTEGER NOT NULL,
height INTEGER NOT NULL, height INTEGER NOT NULL,
@ -97,25 +101,33 @@ impl DbAdapter {
Ok(()) Ok(())
} }
pub fn store_account(&self, seed: &str, sk: &str, ivk: &str, address: &str) -> anyhow::Result<()> { pub fn store_account(&self, name: &str, seed: Option<&str>, sk: Option<&str>, ivk: &str, address: &str) -> anyhow::Result<u32> {
self.connection.execute( self.connection.execute(
"INSERT INTO accounts(seed, sk, ivk, address) VALUES (?1, ?2, ?3, ?4) "INSERT INTO accounts(name, seed, sk, ivk, address) VALUES (?1, ?2, ?3, ?4, ?5)
ON CONFLICT DO NOTHING", ON CONFLICT DO NOTHING",
params![seed, sk, ivk, address], params![name, seed, sk, ivk, address],
)?; )?;
Ok(()) let id_tx: u32 = self.connection.query_row(
"SELECT id_account FROM accounts WHERE sk = ?1",
params![sk],
|row| row.get(0),
)?;
Ok(id_tx)
} }
pub fn has_account(&self, account: u32) -> anyhow::Result<bool> { pub fn get_fvks(&self) -> anyhow::Result<HashMap<u32, String>> {
let r: Option<i32> = self let mut statement = self.connection.prepare("SELECT id_account, ivk FROM accounts")?;
.connection let rows = statement.query_map(NO_PARAMS, |row| {
.query_row( let account: u32 = row.get(0)?;
"SELECT 1 FROM accounts WHERE id_account = ?1", let ivk: String = row.get(1)?;
params![account], Ok((account, ivk))
|row| row.get(0), })?;
) let mut fvks: HashMap<u32, String> = HashMap::new();
.optional()?; for r in rows {
Ok(r.is_some()) let row = r?;
fvks.insert(row.0, row.1);
}
Ok(fvks)
} }
pub fn trim_to_height(&mut self, height: u32) -> anyhow::Result<()> { pub fn trim_to_height(&mut self, height: u32) -> anyhow::Result<()> {
@ -161,16 +173,17 @@ impl DbAdapter {
pub fn store_transaction( pub fn store_transaction(
&self, &self,
txid: &[u8], txid: &[u8],
account: u32,
height: u32, height: u32,
timestamp: u32, timestamp: u32,
tx_index: u32, tx_index: u32,
) -> anyhow::Result<u32> { ) -> anyhow::Result<u32> {
log::debug!("+transaction"); log::debug!("+transaction");
self.connection.execute( self.connection.execute(
"INSERT INTO transactions(txid, height, timestamp, tx_index, value) "INSERT INTO transactions(account, txid, height, timestamp, tx_index, value)
VALUES (?1, ?2, ?3, ?4, 0) VALUES (?1, ?2, ?3, ?4, ?5, 0)
ON CONFLICT DO NOTHING", ON CONFLICT DO NOTHING",
params![txid, height, timestamp, tx_index], params![account, txid, height, timestamp, tx_index],
)?; )?;
let id_tx: u32 = self.connection.query_row( let id_tx: u32 = self.connection.query_row(
"SELECT id_tx FROM transactions WHERE txid = ?1", "SELECT id_tx FROM transactions WHERE txid = ?1",
@ -188,9 +201,9 @@ impl DbAdapter {
position: usize, position: usize,
) -> anyhow::Result<u32> { ) -> anyhow::Result<u32> {
log::debug!("+received_note {}", id_tx); log::debug!("+received_note {}", id_tx);
self.connection.execute("INSERT INTO received_notes(tx, height, position, output_index, diversifier, value, rcm, nf, spent) self.connection.execute("INSERT INTO received_notes(account, tx, height, position, output_index, diversifier, value, rcm, nf, spent)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
ON CONFLICT DO NOTHING", params![id_tx, note.height, position as u32, note.output_index, note.diversifier, note.value as i64, note.rcm, note.nf, note.spent])?; ON CONFLICT DO NOTHING", params![note.account, id_tx, note.height, position as u32, note.output_index, note.diversifier, note.value as i64, note.rcm, note.nf, note.spent])?;
let id_note: u32 = self.connection.query_row( let id_note: u32 = self.connection.query_row(
"SELECT id_note FROM received_notes WHERE tx = ?1 AND output_index = ?2", "SELECT id_note FROM received_notes WHERE tx = ?1 AND output_index = ?2",
params![id_tx, note.output_index], params![id_tx, note.output_index],
@ -226,28 +239,32 @@ impl DbAdapter {
Ok(()) Ok(())
} }
pub fn get_received_note_value(&self, nf: &Nf) -> anyhow::Result<i64> { pub fn get_received_note_value(&self, nf: &Nf) -> anyhow::Result<(u32, i64)> {
let value: i64 = self.connection.query_row( let (account, value) = self.connection.query_row(
"SELECT value FROM received_notes WHERE nf = ?1", "SELECT account, value FROM received_notes WHERE nf = ?1",
params![nf.0.to_vec()], params![nf.0.to_vec()],
|row| row.get(0), |row| {
let account: u32 = row.get(0)?;
let value: i64 = row.get(1)?;
Ok((account, value))
},
)?; )?;
Ok(value) Ok((account, value))
} }
pub fn get_balance(&self) -> anyhow::Result<u64> { pub fn get_balance(&self, account: u32) -> anyhow::Result<u64> {
let balance: Option<i64> = self.connection.query_row( let balance: Option<i64> = self.connection.query_row(
"SELECT SUM(value) FROM received_notes WHERE spent IS NULL OR spent = 0", "SELECT SUM(value) FROM received_notes WHERE (spent IS NULL OR spent = 0) AND account = ?1",
NO_PARAMS, params![account],
|row| row.get(0), |row| row.get(0),
)?; )?;
Ok(balance.unwrap_or(0) as u64) Ok(balance.unwrap_or(0) as u64)
} }
pub fn get_spendable_balance(&self, anchor_height: u32) -> anyhow::Result<u64> { pub fn get_spendable_balance(&self, account: u32, anchor_height: u32) -> anyhow::Result<u64> {
let balance: Option<i64> = self.connection.query_row( let balance: Option<i64> = self.connection.query_row(
"SELECT SUM(value) FROM received_notes WHERE spent IS NULL AND height <= ?1", "SELECT SUM(value) FROM received_notes WHERE spent IS NULL AND height <= ?1 AND account = ?2",
params![anchor_height], params![anchor_height, account],
|row| row.get(0), |row| row.get(0),
)?; )?;
Ok(balance.unwrap_or(0) as u64) Ok(balance.unwrap_or(0) as u64)
@ -309,18 +326,23 @@ impl DbAdapter {
}) })
} }
pub fn get_nullifiers(&self) -> anyhow::Result<HashMap<Nf, u32>> { pub fn get_nullifiers(&self) -> anyhow::Result<HashMap<Nf, NfRef>> {
let mut statement = self let mut statement = self
.connection .connection
.prepare("SELECT id_note, nf FROM received_notes WHERE spent IS NULL OR spent = 0")?; .prepare("SELECT id_note, account, nf FROM received_notes WHERE spent IS NULL OR spent = 0")?;
let nfs_res = statement.query_map(NO_PARAMS, |row| { let nfs_res = statement.query_map(NO_PARAMS, |row| {
let id_note: u32 = row.get(0)?; let id_note: u32 = row.get(0)?;
let nf_vec: Vec<u8> = row.get(1)?; let account: u32 = row.get(1)?;
let nf_vec: Vec<u8> = row.get(2)?;
let mut nf = [0u8; 32]; let mut nf = [0u8; 32];
nf.clone_from_slice(&nf_vec); nf.clone_from_slice(&nf_vec);
Ok((id_note, nf)) let nf_ref = NfRef {
id_note,
account
};
Ok((nf_ref, nf))
})?; })?;
let mut nfs: HashMap<Nf, u32> = HashMap::new(); let mut nfs: HashMap<Nf, NfRef> = HashMap::new();
for n in nfs_res { for n in nfs_res {
let n = n?; let n = n?;
nfs.insert(Nf(n.1), n.0); nfs.insert(Nf(n.1), n.0);
@ -329,11 +351,11 @@ impl DbAdapter {
Ok(nfs) Ok(nfs)
} }
pub fn get_nullifier_amounts(&self) -> anyhow::Result<HashMap<Vec<u8>, u64>> { pub fn get_nullifier_amounts(&self, account: u32) -> anyhow::Result<HashMap<Vec<u8>, u64>> {
let mut statement = self let mut statement = self
.connection .connection
.prepare("SELECT value, nf FROM received_notes WHERE spent IS NULL OR spent = 0")?; .prepare("SELECT value, nf FROM received_notes WHERE account = ?1 AND (spent IS NULL OR spent = 0)")?;
let nfs_res = statement.query_map(NO_PARAMS, |row| { let nfs_res = statement.query_map(params![account], |row| {
let amount: i64 = row.get(0)?; let amount: i64 = row.get(0)?;
let nf: Vec<u8> = row.get(1)?; let nf: Vec<u8> = row.get(1)?;
Ok((amount, nf)) Ok((amount, nf))
@ -349,15 +371,16 @@ impl DbAdapter {
pub fn get_spendable_notes( pub fn get_spendable_notes(
&self, &self,
account: u32,
anchor_height: u32, anchor_height: u32,
fvk: &ExtendedFullViewingKey, fvk: &ExtendedFullViewingKey,
) -> anyhow::Result<Vec<SpendableNote>> { ) -> anyhow::Result<Vec<SpendableNote>> {
let mut statement = self.connection.prepare( let mut statement = self.connection.prepare(
"SELECT id_note, diversifier, value, rcm, witness FROM received_notes r, sapling_witnesses w WHERE spent IS NULL "SELECT id_note, diversifier, value, rcm, witness FROM received_notes r, sapling_witnesses w WHERE spent IS NULL AND account = ?2
AND w.height = ( AND w.height = (
SELECT MAX(height) FROM sapling_witnesses WHERE height <= ?1 SELECT MAX(height) FROM sapling_witnesses WHERE height <= ?1
) AND r.id_note = w.note")?; ) AND r.id_note = w.note")?;
let notes = statement.query_map(params![anchor_height], |row| { let notes = statement.query_map(params![anchor_height, account], |row| {
let id_note: u32 = row.get(0)?; let id_note: u32 = row.get(0)?;
let diversifier: Vec<u8> = row.get(1)?; let diversifier: Vec<u8> = row.get(1)?;
@ -401,32 +424,34 @@ impl DbAdapter {
Ok(()) Ok(())
} }
pub fn get_seed(&self, account: u32) -> anyhow::Result<String> { pub fn get_backup(&self, account: u32) -> anyhow::Result<(Option<String>, Option<String>, String)> {
log::debug!("+get_seed"); log::debug!("+get_backup");
let ivk = self.connection.query_row( let (seed, sk, ivk) = self.connection.query_row(
"SELECT seed FROM accounts WHERE id_account = ?1", "SELECT seed, sk, ivk FROM accounts WHERE id_account = ?1",
params![account], params![account],
|row| { |row| {
let seed: Option<String> = row.get(0)?;
let sk: Option<String> = row.get(0)?;
let ivk: String = row.get(0)?; let ivk: String = row.get(0)?;
Ok(ivk) Ok((seed, sk, ivk))
}, },
)?; )?;
log::debug!("-get_seed"); log::debug!("-get_backup");
Ok(ivk) Ok((seed, sk, ivk))
} }
pub fn get_sk(&self, account: u32) -> anyhow::Result<String> { pub fn get_sk(&self, account: u32) -> anyhow::Result<String> {
log::debug!("+get_sk"); log::info!("+get_sk");
let ivk = self.connection.query_row( let sk = self.connection.query_row(
"SELECT sk FROM accounts WHERE id_account = ?1", "SELECT sk FROM accounts WHERE id_account = ?1",
params![account], params![account],
|row| { |row| {
let ivk: String = row.get(0)?; let sk: String = row.get(0)?;
Ok(ivk) Ok(sk)
}, },
)?; )?;
log::debug!("-get_sk"); log::info!("-get_sk");
Ok(ivk) Ok(sk)
} }
pub fn get_ivk(&self, account: u32) -> anyhow::Result<String> { pub fn get_ivk(&self, account: u32) -> anyhow::Result<String> {
@ -456,9 +481,10 @@ mod tests {
db.trim_to_height(0).unwrap(); db.trim_to_height(0).unwrap();
db.store_block(1, &[0u8; 32], 0, &CTree::new()).unwrap(); db.store_block(1, &[0u8; 32], 0, &CTree::new()).unwrap();
let id_tx = db.store_transaction(&[0; 32], 1, 0, 20).unwrap(); let id_tx = db.store_transaction(&[0; 32], 1, 1, 0, 20).unwrap();
db.store_received_note( db.store_received_note(
&ReceivedNote { &ReceivedNote {
account: 1,
height: 1, height: 1,
output_index: 0, output_index: 0,
diversifier: vec![], diversifier: vec![],
@ -485,7 +511,7 @@ mod tests {
#[test] #[test]
fn test_balance() { fn test_balance() {
let db = DbAdapter::new(DEFAULT_DB_PATH).unwrap(); let db = DbAdapter::new(DEFAULT_DB_PATH).unwrap();
let balance = db.get_balance().unwrap(); let balance = db.get_balance(1).unwrap();
println!("{}", balance); println!("{}", balance);
} }
} }

View File

@ -1,5 +1,4 @@
use crate::NETWORK; use crate::NETWORK;
use anyhow::anyhow;
use bip39::{Language, Mnemonic, Seed}; use bip39::{Language, Mnemonic, Seed};
use zcash_client_backend::encoding::{ use zcash_client_backend::encoding::{
decode_extended_full_viewing_key, decode_extended_spending_key, decode_extended_full_viewing_key, decode_extended_spending_key,
@ -8,8 +7,40 @@ use zcash_client_backend::encoding::{
use zcash_primitives::consensus::Parameters; use zcash_primitives::consensus::Parameters;
use zcash_primitives::zip32::{ChildIndex, ExtendedFullViewingKey, ExtendedSpendingKey}; use zcash_primitives::zip32::{ChildIndex, ExtendedFullViewingKey, ExtendedSpendingKey};
pub fn get_secret_key(seed: &str) -> anyhow::Result<String> { pub fn decode_key(key: &str) -> anyhow::Result<(Option<String>, Option<String>, String, String)> {
let mnemonic = Mnemonic::from_phrase(&seed, Language::English)?; let res =
if let Ok(mnemonic) = Mnemonic::from_phrase(&key, Language::English) {
let (sk, ivk, pa) = derive_secret_key(&mnemonic)?;
Ok((Some(key.to_string()), Some(sk), ivk, pa))
}
else if let Ok(Some(sk)) = decode_extended_spending_key(NETWORK.hrp_sapling_extended_spending_key(), &key) {
let (ivk, pa) = derive_viewing_key(&sk)?;
Ok((None, Some(key.to_string()), ivk, pa))
}
else if let Ok(Some(fvk)) = decode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &key) {
let pa = derive_address(&fvk)?;
Ok((None, None, key.to_string(), pa))
}
else {
Err(anyhow::anyhow!("Not a valid key"))
};
res
}
pub fn is_valid_key(key: &str) -> bool {
if Mnemonic::from_phrase(&key, Language::English).is_ok() {
return true;
}
if decode_extended_spending_key(NETWORK.hrp_sapling_extended_spending_key(), &key).is_ok() {
return true;
}
if decode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &key).is_ok() {
return true;
}
false
}
pub fn derive_secret_key(mnemonic: &Mnemonic) -> anyhow::Result<(String, String, String)> {
let seed = Seed::new(&mnemonic, ""); let seed = Seed::new(&mnemonic, "");
let master = ExtendedSpendingKey::master(seed.as_bytes()); let master = ExtendedSpendingKey::master(seed.as_bytes());
let path = [ let path = [
@ -18,28 +49,21 @@ pub fn get_secret_key(seed: &str) -> anyhow::Result<String> {
ChildIndex::Hardened(0), ChildIndex::Hardened(0),
]; ];
let extsk = ExtendedSpendingKey::from_path(&master, &path); let extsk = ExtendedSpendingKey::from_path(&master, &path);
let spending_key = let sk =
encode_extended_spending_key(NETWORK.hrp_sapling_extended_spending_key(), &extsk); encode_extended_spending_key(NETWORK.hrp_sapling_extended_spending_key(), &extsk);
Ok(spending_key) let (fvk, pa) = derive_viewing_key(&extsk)?;
Ok((sk, fvk, pa))
} }
pub fn get_viewing_key(secret_key: &str) -> anyhow::Result<String> { pub fn derive_viewing_key(extsk: &ExtendedSpendingKey) -> anyhow::Result<(String, String)> {
let extsk = let fvk = ExtendedFullViewingKey::from(extsk);
decode_extended_spending_key(NETWORK.hrp_sapling_extended_spending_key(), secret_key)? let pa = derive_address(&fvk)?;
.ok_or(anyhow!("Invalid Secret Key"))?; let fvk = encode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &fvk);
let fvk = ExtendedFullViewingKey::from(&extsk); Ok((fvk, pa))
let viewing_key =
encode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &fvk);
Ok(viewing_key)
} }
pub fn get_address(viewing_key: &str) -> anyhow::Result<String> { pub fn derive_address(fvk: &ExtendedFullViewingKey) -> anyhow::Result<String> {
let fvk = decode_extended_full_viewing_key(
NETWORK.hrp_sapling_extended_full_viewing_key(),
&viewing_key,
)?
.ok_or(anyhow!("Invalid Viewing Key"))?;
let (_, payment_address) = fvk.default_address().unwrap(); let (_, payment_address) = fvk.default_address().unwrap();
let address = encode_payment_address(NETWORK.hrp_sapling_payment_address(), &payment_address); let address = encode_payment_address(NETWORK.hrp_sapling_payment_address(), &payment_address);
Ok(address) Ok(address)

View File

@ -3,7 +3,16 @@ use zcash_primitives::consensus::Network;
#[path = "generated/cash.z.wallet.sdk.rpc.rs"] #[path = "generated/cash.z.wallet.sdk.rpc.rs"]
pub mod lw_rpc; pub mod lw_rpc;
pub const NETWORK: Network = Network::MainNetwork; pub const NETWORK: Network = Network::TestNetwork;
// Mainnet
// pub const LWD_URL: &str = "https://mainnet.lightwalletd.com:9067";
// pub const LWD_URL: &str = "https://lwdv3.zecwallet.co";
// Testnet
pub const LWD_URL: &str = "https://testnet.lightwalletd.com:9067";
// pub const LWD_URL: &str = "http://lwd.hanh.me:9067";
// pub const LWD_URL: &str = "http://127.0.0.1:9067";
mod builder; mod builder;
mod chain; mod chain;
@ -18,14 +27,14 @@ mod wallet;
pub use crate::builder::advance_tree; pub use crate::builder::advance_tree;
pub use crate::chain::{ pub use crate::chain::{
calculate_tree_state_v2, connect_lightwalletd, download_chain, get_latest_height, sync, calculate_tree_state_v2, connect_lightwalletd, download_chain, get_latest_height, sync,
DecryptNode, LWD_URL, ChainError DecryptNode, ChainError
}; };
pub use crate::commitment::{CTree, Witness}; pub use crate::commitment::{CTree, Witness};
pub use crate::db::DbAdapter; pub use crate::db::DbAdapter;
pub use crate::key::{get_address, get_secret_key, get_viewing_key};
pub use crate::lw_rpc::compact_tx_streamer_client::CompactTxStreamerClient; pub use crate::lw_rpc::compact_tx_streamer_client::CompactTxStreamerClient;
pub use crate::lw_rpc::*; pub use crate::lw_rpc::*;
pub use crate::mempool::MemPool; pub use crate::mempool::MemPool;
pub use crate::scan::{latest_height, scan_all, sync_async}; pub use crate::scan::{latest_height, scan_all, sync_async};
pub use crate::wallet::{Wallet, WalletBalance, DEFAULT_ACCOUNT}; pub use crate::wallet::{Wallet, WalletBalance};
pub use crate::print::*; pub use crate::print::*;
pub use crate::key::is_valid_key;

View File

@ -1,7 +1,7 @@
use bip39::{Language, Mnemonic}; use bip39::{Language, Mnemonic};
use rand::rngs::OsRng; use rand::rngs::OsRng;
use rand::RngCore; use rand::RngCore;
use sync::{DbAdapter, Wallet, DEFAULT_ACCOUNT, ChainError, Witness, print_witness2}; use sync::{DbAdapter, Wallet, ChainError, Witness, print_witness2};
use rusqlite::NO_PARAMS; use rusqlite::NO_PARAMS;
const DB_NAME: &str = "zec.db"; const DB_NAME: &str = "zec.db";
@ -23,18 +23,21 @@ async fn test() -> anyhow::Result<()> {
log::info!("Height = {}", height); log::info!("Height = {}", height);
}; };
let wallet = Wallet::new(DB_NAME); let wallet = Wallet::new(DB_NAME);
wallet.new_account_with_seed(&seed).unwrap(); wallet.new_account_with_key("test", &seed).unwrap();
let res = wallet.sync(DEFAULT_ACCOUNT, progress).await; let res = wallet.sync(progress).await;
if let Err(err) = res { if let Err(err) = res {
if let Some(_) = err.downcast_ref::<ChainError>() { if let Some(_) = err.downcast_ref::<ChainError>() {
println!("REORG"); println!("REORG");
} }
else {
panic!(err);
} }
// let tx_id = wallet }
// .send_payment(DEFAULT_ACCOUNT, &address, 50000) let tx_id = wallet
// .await .send_payment(1, &address, 50000)
// .unwrap(); .await
// println!("TXID = {}", tx_id); .unwrap();
println!("TXID = {}", tx_id);
Ok(()) Ok(())
} }
@ -57,7 +60,7 @@ fn test_rewind() {
#[allow(dead_code)] #[allow(dead_code)]
fn test_get_balance() { fn test_get_balance() {
let db = DbAdapter::new(DB_NAME).unwrap(); let db = DbAdapter::new(DB_NAME).unwrap();
let balance = db.get_balance().unwrap(); let balance = db.get_balance(1).unwrap();
println!("Balance = {}", (balance as f64) / 100_000_000.0); println!("Balance = {}", (balance as f64) / 100_000_000.0);
} }
@ -77,6 +80,7 @@ fn test_invalid_witness() {
print_witness2(&w); print_witness2(&w);
} }
#[allow(dead_code)]
fn w() { fn w() {
let db = DbAdapter::new("zec.db").unwrap(); let db = DbAdapter::new("zec.db").unwrap();
// let w_b: Vec<u8> = db.connection.query_row("SELECT witness FROM sapling_witnesses WHERE note = 66 AND height = 1466097", NO_PARAMS, |row| row.get(0)).unwrap(); // let w_b: Vec<u8> = db.connection.query_row("SELECT witness FROM sapling_witnesses WHERE note = 66 AND height = 1466097", NO_PARAMS, |row| row.get(0)).unwrap();

View File

@ -21,6 +21,7 @@ struct MemPoolTransacton {
pub struct MemPool { pub struct MemPool {
db_path: String, db_path: String,
account: u32,
ivk: Option<SaplingIvk>, ivk: Option<SaplingIvk>,
height: BlockHeight, height: BlockHeight,
transactions: HashMap<Vec<u8>, MemPoolTransacton>, transactions: HashMap<Vec<u8>, MemPoolTransacton>,
@ -32,6 +33,7 @@ impl MemPool {
pub fn new(db_path: &str) -> MemPool { pub fn new(db_path: &str) -> MemPool {
MemPool { MemPool {
db_path: db_path.to_string(), db_path: db_path.to_string(),
account: 0,
ivk: None, ivk: None,
height: BlockHeight::from(0), height: BlockHeight::from(0),
transactions: HashMap::new(), transactions: HashMap::new(),
@ -43,11 +45,13 @@ impl MemPool {
pub fn set_account(&mut self, account: u32) -> anyhow::Result<()> { pub fn set_account(&mut self, account: u32) -> anyhow::Result<()> {
let db = DbAdapter::new(&self.db_path)?; let db = DbAdapter::new(&self.db_path)?;
let ivk = db.get_ivk(account)?; let ivk = db.get_ivk(account)?;
self.account = account;
self.set_ivk(&ivk); self.set_ivk(&ivk);
self.clear()?;
Ok(()) Ok(())
} }
pub fn set_ivk(&mut self, ivk: &str) { fn set_ivk(&mut self, ivk: &str) {
let fvk = let fvk =
decode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &ivk) decode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &ivk)
.unwrap() .unwrap()
@ -56,30 +60,30 @@ impl MemPool {
self.ivk = Some(ivk); self.ivk = Some(ivk);
} }
pub async fn scan(&mut self) -> anyhow::Result<()> { pub async fn scan(&mut self) -> anyhow::Result<i64> {
if self.ivk.is_some() { if self.ivk.is_some() {
let ivk = self.ivk.as_ref().unwrap().clone(); let ivk = self.ivk.as_ref().unwrap().clone();
let mut client = connect_lightwalletd().await?; let mut client = connect_lightwalletd().await?;
let height = BlockHeight::from(get_latest_height(&mut client).await?); let height = BlockHeight::from(get_latest_height(&mut client).await?);
if self.height != height { if self.height != height {
// New blocks invalidate the mempool // New blocks invalidate the mempool
let db = DbAdapter::new(&self.db_path)?; self.clear()?;
self.clear(&db)?;
} }
self.height = height; self.height = height;
self.update(&mut client, &ivk).await?; self.update(&mut client, &ivk).await?;
} }
Ok(()) Ok(self.balance)
} }
pub fn get_unconfirmed_balance(&self) -> i64 { pub fn get_unconfirmed_balance(&self) -> i64 {
self.balance self.balance
} }
fn clear(&mut self, db: &DbAdapter) -> anyhow::Result<()> { fn clear(&mut self) -> anyhow::Result<()> {
let db = DbAdapter::new(&self.db_path)?;
self.height = BlockHeight::from_u32(0); self.height = BlockHeight::from_u32(0);
self.nfs = db.get_nullifier_amounts()?; self.nfs = db.get_nullifier_amounts(self.account)?;
self.transactions.clear(); self.transactions.clear();
self.balance = 0; self.balance = 0;
Ok(()) Ok(())

View File

@ -1,5 +1,5 @@
use crate::builder::BlockProcessor; use crate::builder::BlockProcessor;
use crate::chain::Nf; use crate::chain::{Nf, NfRef};
use crate::db::{DbAdapter, ReceivedNote}; use crate::db::{DbAdapter, ReceivedNote};
use crate::lw_rpc::compact_tx_streamer_client::CompactTxStreamerClient; use crate::lw_rpc::compact_tx_streamer_client::CompactTxStreamerClient;
use crate::{ use crate::{
@ -18,9 +18,12 @@ use zcash_primitives::consensus::{NetworkUpgrade, Parameters};
use zcash_primitives::sapling::Node; use zcash_primitives::sapling::Node;
use zcash_primitives::zip32::ExtendedFullViewingKey; use zcash_primitives::zip32::ExtendedFullViewingKey;
use std::panic; use std::panic;
use std::collections::HashMap;
pub async fn scan_all(fvks: &[ExtendedFullViewingKey]) -> anyhow::Result<()> { pub async fn scan_all(fvks: &[ExtendedFullViewingKey]) -> anyhow::Result<()> {
let decrypter = DecryptNode::new(fvks.to_vec()); let fvks: HashMap<_, _> = fvks.iter().enumerate().map(|(i, fvk)|
(i as u32, fvk.clone())).collect();
let decrypter = DecryptNode::new(fvks);
let total_start = Instant::now(); let total_start = Instant::now();
let mut client = CompactTxStreamerClient::connect(LWD_URL).await?; let mut client = CompactTxStreamerClient::connect(LWD_URL).await?;
@ -68,27 +71,33 @@ impl std::fmt::Debug for Blocks {
pub type ProgressCallback = Arc<Mutex<dyn Fn(u32) + Send>>; pub type ProgressCallback = Arc<Mutex<dyn Fn(u32) + Send>>;
pub async fn sync_async( pub async fn sync_async(
ivk: &str,
chunk_size: u32, chunk_size: u32,
db_path: &str, db_path: &str,
target_height_offset: u32, target_height_offset: u32,
progress_callback: ProgressCallback, progress_callback: ProgressCallback,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let db_path = db_path.to_string(); let db_path = db_path.to_string();
let fvk =
decode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &ivk)?
.ok_or_else(|| anyhow::anyhow!("Invalid key"))?;
let decrypter = DecryptNode::new(vec![fvk]);
let mut client = connect_lightwalletd().await?; let mut client = connect_lightwalletd().await?;
let (start_height, mut prev_hash) = { let (start_height, mut prev_hash, fvks) = {
let db = DbAdapter::new(&db_path)?; let db = DbAdapter::new(&db_path)?;
let height = db.get_db_height()?; let height = db.get_db_height()?;
(height, db.get_db_hash(height)?) let hash = db.get_db_hash(height)?;
let fvks = db.get_fvks()?;
(height, hash, fvks)
}; };
let end_height = get_latest_height(&mut client).await?; let end_height = get_latest_height(&mut client).await?;
let end_height = (end_height - target_height_offset).max(start_height); let end_height = (end_height - target_height_offset).max(start_height);
let fvks: HashMap<_, _> = fvks.iter().map(|(&account, fvk)| {
let fvk =
decode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &fvk)
.unwrap()
.unwrap();
(account, fvk)
}).collect();
let decrypter = DecryptNode::new(fvks);
let (downloader_tx, mut download_rx) = mpsc::channel::<Range<u32>>(2); let (downloader_tx, mut download_rx) = mpsc::channel::<Range<u32>>(2);
let (processor_tx, mut processor_rx) = mpsc::channel::<Blocks>(1); let (processor_tx, mut processor_rx) = mpsc::channel::<Blocks>(1);
@ -133,9 +142,9 @@ pub async fn sync_async(
for b in dec_blocks.iter() { for b in dec_blocks.iter() {
let mut my_nfs: Vec<Nf> = vec![]; let mut my_nfs: Vec<Nf> = vec![];
for nf in b.spends.iter() { for nf in b.spends.iter() {
if let Some(&id) = nfs.get(nf) { if let Some(&nf_ref) = nfs.get(nf) {
log::info!("NF FOUND {} {}", id, b.height); log::info!("NF FOUND {} {}", nf_ref.id_note, b.height);
db.mark_spent(id, b.height)?; db.mark_spent(nf_ref.id_note, b.height)?;
my_nfs.push(*nf); my_nfs.push(*nf);
} }
} }
@ -152,12 +161,14 @@ pub async fn sync_async(
let id_tx = db.store_transaction( let id_tx = db.store_transaction(
&n.txid, &n.txid,
n.account,
n.height, n.height,
b.compact_block.time, b.compact_block.time,
n.tx_index as u32, n.tx_index as u32,
)?; )?;
let id_note = db.store_received_note( let id_note = db.store_received_note(
&ReceivedNote { &ReceivedNote {
account: n.account,
height: n.height, height: n.height,
output_index: n.output_index as u32, output_index: n.output_index as u32,
diversifier: n.pa.diversifier().0.to_vec(), diversifier: n.pa.diversifier().0.to_vec(),
@ -170,7 +181,7 @@ pub async fn sync_async(
n.position_in_block, n.position_in_block,
)?; )?;
db.add_value(id_tx, note.value as i64)?; db.add_value(id_tx, note.value as i64)?;
nfs.insert(Nf(nf.0), id_note); nfs.insert(Nf(nf.0), NfRef { id_note, account: n.account });
let w = Witness::new(p as usize, id_note, Some(n.clone())); let w = Witness::new(p as usize, id_note, Some(n.clone()));
witnesses.push(w); witnesses.push(w);
@ -183,10 +194,11 @@ pub async fn sync_async(
nf.copy_from_slice(&cs.nf); nf.copy_from_slice(&cs.nf);
let nf = Nf(nf); let nf = Nf(nf);
if my_nfs.contains(&nf) { if my_nfs.contains(&nf) {
let note_value = db.get_received_note_value(&nf)?; let (account, note_value) = db.get_received_note_value(&nf)?;
let txid = &*tx.hash; let txid = &*tx.hash;
let id_tx = db.store_transaction( let id_tx = db.store_transaction(
txid, txid,
account,
b.height, b.height,
b.compact_block.time, b.compact_block.time,
tx_index as u32, tx_index as u32,

View File

@ -1,7 +1,6 @@
use crate::chain::send_transaction; use crate::chain::send_transaction;
use crate::mempool::MemPool;
use crate::scan::ProgressCallback; use crate::scan::ProgressCallback;
use crate::{connect_lightwalletd, get_address, get_latest_height, get_secret_key, get_viewing_key, DbAdapter, NETWORK, BlockId, CTree}; use crate::{connect_lightwalletd, get_latest_height, DbAdapter, NETWORK, BlockId, CTree};
use anyhow::Context; use anyhow::Context;
use bip39::{Language, Mnemonic}; use bip39::{Language, Mnemonic};
use rand::prelude::SliceRandom; use rand::prelude::SliceRandom;
@ -11,7 +10,7 @@ use std::sync::Arc;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use zcash_client_backend::address::RecipientAddress; use zcash_client_backend::address::RecipientAddress;
use zcash_client_backend::data_api::wallet::ANCHOR_OFFSET; use zcash_client_backend::data_api::wallet::ANCHOR_OFFSET;
use zcash_client_backend::encoding::{decode_extended_spending_key, decode_payment_address}; use zcash_client_backend::encoding::decode_extended_spending_key;
use zcash_params::{OUTPUT_PARAMS, SPEND_PARAMS}; use zcash_params::{OUTPUT_PARAMS, SPEND_PARAMS};
use zcash_primitives::consensus::{BlockHeight, BranchId, Parameters}; use zcash_primitives::consensus::{BlockHeight, BranchId, Parameters};
use zcash_primitives::transaction::builder::Builder; use zcash_primitives::transaction::builder::Builder;
@ -20,8 +19,8 @@ use zcash_primitives::transaction::components::Amount;
use zcash_primitives::zip32::ExtendedFullViewingKey; use zcash_primitives::zip32::ExtendedFullViewingKey;
use zcash_proofs::prover::LocalTxProver; use zcash_proofs::prover::LocalTxProver;
use tonic::Request; use tonic::Request;
use crate::key::{is_valid_key, decode_key};
pub const DEFAULT_ACCOUNT: u32 = 1;
const DEFAULT_CHUNK_SIZE: u32 = 100_000; const DEFAULT_CHUNK_SIZE: u32 = 100_000;
pub struct Wallet { pub struct Wallet {
@ -59,48 +58,48 @@ impl Wallet {
} }
} }
pub fn valid_seed(seed: &str) -> bool { pub fn valid_key(key: &str) -> bool {
get_secret_key(&seed).is_ok() is_valid_key(key)
} }
pub fn valid_address(address: &str) -> bool { pub fn valid_address(address: &str) -> bool {
decode_payment_address(NETWORK.hrp_sapling_payment_address(), address).is_ok() let recipient = RecipientAddress::decode(&NETWORK, address);
recipient.is_some()
} }
pub fn new_seed(&self) -> anyhow::Result<()> { pub fn new_account(&self, name: &str, data: &str) -> anyhow::Result<u32> {
if data.is_empty() {
let mut entropy = [0u8; 32]; let mut entropy = [0u8; 32];
OsRng.fill_bytes(&mut entropy); OsRng.fill_bytes(&mut entropy);
let mnemonic = Mnemonic::from_entropy(&entropy, Language::English)?; let mnemonic = Mnemonic::from_entropy(&entropy, Language::English)?;
let seed = mnemonic.phrase(); let seed = mnemonic.phrase();
self.new_account_with_seed(seed)?; self.new_account_with_key(name, seed)
Ok(()) }
else {
self.new_account_with_key(name, data)
}
} }
pub fn get_seed(&self, account: u32) -> anyhow::Result<String> { pub fn get_backup(&self, account: u32) -> anyhow::Result<String> {
self.db.get_seed(account) let (seed, sk, ivk) = self.db.get_backup(account)?;
if let Some(seed) = seed { return Ok(seed); }
if let Some(sk) = sk { return Ok(sk); }
Ok(ivk)
} }
pub fn has_account(&self, account: u32) -> anyhow::Result<bool> { pub fn new_account_with_key(&self, name: &str, key: &str) -> anyhow::Result<u32> {
self.db.has_account(account) let (seed, sk, ivk, pa) = decode_key(key)?;
} let account = self.db.store_account(name, seed.as_deref(), sk.as_deref(), &ivk, &pa)?;
Ok(account)
pub fn new_account_with_seed(&self, seed: &str) -> anyhow::Result<()> {
let sk = get_secret_key(&seed).unwrap();
let vk = get_viewing_key(&sk).unwrap();
let pa = get_address(&vk).unwrap();
self.db.store_account(seed, &sk, &vk, &pa)?;
Ok(())
} }
async fn scan_async( async fn scan_async(
ivk: &str,
db_path: &str, db_path: &str,
chunk_size: u32, chunk_size: u32,
target_height_offset: u32, target_height_offset: u32,
progress_callback: ProgressCallback, progress_callback: ProgressCallback,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
crate::scan::sync_async( crate::scan::sync_async(
ivk,
chunk_size, chunk_size,
db_path, db_path,
target_height_offset, target_height_offset,
@ -118,22 +117,19 @@ impl Wallet {
// Not a method in order to avoid locking the instance // Not a method in order to avoid locking the instance
pub async fn sync_ex( pub async fn sync_ex(
db_path: &str, db_path: &str,
ivk: &str,
progress_callback: impl Fn(u32) + Send + 'static, progress_callback: impl Fn(u32) + Send + 'static,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let cb = Arc::new(Mutex::new(progress_callback)); let cb = Arc::new(Mutex::new(progress_callback));
Self::scan_async(&ivk, db_path, DEFAULT_CHUNK_SIZE, 10, cb.clone()).await?; Self::scan_async(db_path, DEFAULT_CHUNK_SIZE, 10, cb.clone()).await?;
Self::scan_async(&ivk, db_path, DEFAULT_CHUNK_SIZE, 0, cb.clone()).await?; Self::scan_async(db_path, DEFAULT_CHUNK_SIZE, 0, cb.clone()).await?;
Ok(()) Ok(())
} }
pub async fn sync( pub async fn sync(
&self, &self,
account: u32,
progress_callback: impl Fn(u32) + Send + 'static, progress_callback: impl Fn(u32) + Send + 'static,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let ivk = self.get_ivk(account)?; Self::sync_ex(&self.db_path, progress_callback).await
Self::sync_ex(&self.db_path, &ivk, progress_callback).await
} }
pub async fn skip_to_last_height(&self) -> anyhow::Result<()> { pub async fn skip_to_last_height(&self) -> anyhow::Result<()> {
@ -155,20 +151,6 @@ impl Wallet {
self.db.trim_to_height(height) self.db.trim_to_height(height)
} }
pub async fn get_balance(&self, mempool: &MemPool) -> anyhow::Result<WalletBalance> {
let last_height = Self::get_latest_height().await?;
let anchor_height = last_height - ANCHOR_OFFSET;
let confirmed = self.db.get_balance()?;
let unconfirmed = mempool.get_unconfirmed_balance();
let spendable = self.db.get_spendable_balance(anchor_height)?;
Ok(WalletBalance {
confirmed,
unconfirmed,
spendable,
})
}
pub async fn send_payment( pub async fn send_payment(
&self, &self,
account: u32, account: u32,
@ -193,7 +175,7 @@ impl Wallet {
.ok_or_else(|| anyhow::anyhow!("No spendable notes"))?; .ok_or_else(|| anyhow::anyhow!("No spendable notes"))?;
let anchor_height = anchor_height.min(last_height - ANCHOR_OFFSET); let anchor_height = anchor_height.min(last_height - ANCHOR_OFFSET);
log::info!("Anchor = {}", anchor_height); log::info!("Anchor = {}", anchor_height);
let mut notes = self.db.get_spendable_notes(anchor_height, &extfvk)?; let mut notes = self.db.get_spendable_notes(account, anchor_height, &extfvk)?;
notes.shuffle(&mut OsRng); notes.shuffle(&mut OsRng);
log::info!("Spendable notes = {}", notes.len()); log::info!("Spendable notes = {}", notes.len());
@ -260,7 +242,8 @@ impl Wallet {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::wallet::Wallet; use crate::wallet::Wallet;
use crate::{get_address, get_secret_key, get_viewing_key}; use crate::key::derive_secret_key;
use bip39::{Mnemonic, Language};
#[tokio::test] #[tokio::test]
async fn test_wallet_seed() { async fn test_wallet_seed() {
@ -269,7 +252,7 @@ mod tests {
let seed = dotenv::var("SEED").unwrap(); let seed = dotenv::var("SEED").unwrap();
let wallet = Wallet::new("zec.db"); let wallet = Wallet::new("zec.db");
wallet.new_account_with_seed(&seed).unwrap(); wallet.new_account_with_key("test", &seed).unwrap();
} }
#[tokio::test] #[tokio::test]
@ -278,14 +261,11 @@ mod tests {
env_logger::init(); env_logger::init();
let seed = dotenv::var("SEED").unwrap(); let seed = dotenv::var("SEED").unwrap();
let sk = get_secret_key(&seed).unwrap(); let (sk, vk, pa) = derive_secret_key(&Mnemonic::from_phrase(&seed, Language::English).unwrap()).unwrap();
let vk = get_viewing_key(&sk).unwrap(); println!("{} {} {}", sk, vk, pa);
println!("{}", vk); // let wallet = Wallet::new("zec.db");
let pa = get_address(&vk).unwrap(); //
println!("{}", pa); // let tx_id = wallet.send_payment(1, &pa, 1000).await.unwrap();
let wallet = Wallet::new("zec.db"); // println!("TXID = {}", tx_id);
let tx_id = wallet.send_payment(1, &pa, 1000).await.unwrap();
println!("TXID = {}", tx_id);
} }
} }