diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 95751a74..5fd3ebd3 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -23,13 +23,17 @@ use zcash_client_backend::{ wallet::OvkPolicy, }; use zcash_client_sqlite::{ - wallet::init::{init_accounts_table, init_blocks_table, init_data_database}, + wallet::{ + init::{init_accounts_table, init_blocks_table, init_wallet_db}, + get_balance, + get_balance_at, + }, BlockDB, NoteId, WalletDB, - wallet::{UnspentTransactionOutput, get_utxos}, + chain::{UnspentTransactionOutput, get_all_utxos, get_confirmed_utxos_for_address} }; use zcash_primitives::{ block::BlockHash, - consensus::{BlockHeight, BranchId, Parameters}, + consensus::{self,BlockHeight, BranchId, Parameters}, note_encryption::Memo, transaction::components::{ amount::DEFAULT_FEE, @@ -92,13 +96,13 @@ pub const NETWORK: MainNetwork = MAIN_NETWORK; pub const NETWORK: TestNetwork = TEST_NETWORK; -fn wallet_db(db_data: *const u8, - db_data_len: usize) -> Result { +fn wallet_db(params: &P,db_data: *const u8, + db_data_len: usize) -> Result, failure::Error> { let db_data = Path::new(OsStr::from_bytes(unsafe { slice::from_raw_parts(db_data, db_data_len) })); - WalletDB::for_path(db_data) + WalletDB::for_path(db_data, *params) .map_err(|e| format_err!("Error opening wallet database connection: {}", e)) } @@ -139,8 +143,8 @@ pub extern "C" fn zcashlc_init_data_database(db_data: *const u8, db_data_len: us slice::from_raw_parts(db_data, db_data_len) })); - WalletDB::for_path(db_data) - .map(|db| init_data_database(&db)) + WalletDB::for_path(db_data, NETWORK) + .map(|db| init_wallet_db(&db)) .map(|_| 1) .map_err(|e| format_err!("Error while initializing data DB: {}", e)) }); @@ -163,7 +167,7 @@ pub extern "C" fn zcashlc_init_accounts_table( capacity_ret: *mut usize, ) -> *mut *mut c_char { let res = catch_panic(|| { - let db_data = wallet_db(db_data, db_data_len)?; + let db_data = wallet_db(&NETWORK, db_data, db_data_len)?; let seed = unsafe { slice::from_raw_parts(seed, seed_len) }; let accounts = if accounts >= 0 { accounts as u32 @@ -176,7 +180,7 @@ pub extern "C" fn zcashlc_init_accounts_table( .collect(); let extfvks: Vec<_> = extsks.iter().map(ExtendedFullViewingKey::from).collect(); - init_accounts_table(&db_data, &NETWORK, &extfvks) + init_accounts_table(&db_data, &extfvks) .map(|_| { // Return the ExtendedSpendingKeys for the created accounts. let mut v: Vec<_> = extsks @@ -209,7 +213,7 @@ pub extern "C" fn zcashlc_init_accounts_table_with_keys( extfvks_len: usize, ) -> bool { let res = catch_panic(|| { - let db_data = wallet_db(db_data, db_data_len)?; + let db_data = wallet_db(&NETWORK, db_data, db_data_len)?; let extfvks = unsafe { std::slice::from_raw_parts(extfvks, extfvks_len) .into_iter() @@ -220,7 +224,7 @@ pub extern "C" fn zcashlc_init_accounts_table_with_keys( .unwrap() ).collect::>() }; - match init_accounts_table(&db_data, &NETWORK, &extfvks) { + match init_accounts_table(&db_data, &extfvks) { Ok(()) => Ok(true), Err(e) => Err(format_err!("Error while initializing accounts: {}", e)), } @@ -418,7 +422,7 @@ pub extern "C" fn zcashlc_init_blocks_table( sapling_tree_hex: *const c_char, ) -> i32 { let res = catch_panic(|| { - let db_data = wallet_db(db_data, db_data_len)?; + let db_data = wallet_db(&NETWORK, db_data, db_data_len)?; let hash = { let mut hash = hex::decode(unsafe { CStr::from_ptr(hash_hex) }.to_str()?).unwrap(); hash.reverse(); @@ -445,7 +449,7 @@ pub extern "C" fn zcashlc_get_address( account: i32, ) -> *mut c_char { let res = catch_panic(|| { - let db_data = wallet_db(db_data, db_data_len)?; + let db_data = wallet_db(&NETWORK, db_data, db_data_len)?; let account = if account >= 0 { account as u32 } else { @@ -454,7 +458,7 @@ pub extern "C" fn zcashlc_get_address( let account = AccountId(account); - match (&db_data).get_address(&NETWORK, account) { + match (&db_data).get_address(account) { Ok(Some(addr)) => { let addr_str = encode_payment_address(NETWORK.hrp_sapling_payment_address(), &addr); let c_str_addr = CString::new(addr_str).unwrap(); @@ -532,7 +536,7 @@ fn is_valid_transparent_address(address: &str) -> bool { #[no_mangle] pub extern "C" fn zcashlc_get_balance(db_data: *const u8, db_data_len: usize, account: i32) -> i64 { let res = catch_panic(|| { - let db_data = wallet_db(db_data, db_data_len)?; + let db_data = wallet_db(&NETWORK, db_data, db_data_len)?; let account = if account >= 0 { account as u32 @@ -557,7 +561,7 @@ pub extern "C" fn zcashlc_get_verified_balance( account: i32, ) -> i64 { let res = catch_panic(|| { - let db_data = wallet_db(db_data, db_data_len)?; + let db_data = wallet_db(&NETWORK, db_data, db_data_len)?; let account = if account >= 0 { account as u32 } else { @@ -574,7 +578,7 @@ pub extern "C" fn zcashlc_get_verified_balance( }) .and_then(|anchor| { (&db_data) - .get_verified_balance(account, anchor) + .get_balance_at(account, anchor) .map_err(|e| format_err!("Error while fetching verified balance: {}", e)) }) .map(|amount| amount.into()) @@ -1046,9 +1050,9 @@ pub fn double_sha256(payload: &[u8]) -> Vec { /// /// -fn shield_funds( +fn shield_funds( db_cache: &BlockDB, - db_data: &WalletDB, + db_data: &WalletDB

, account: u32, tsk: &str, extsk: &str, @@ -1066,7 +1070,6 @@ fn shield_funds( }, }; - // grab secret private key for t-funds let sk = match secp256k1::key::SecretKey::from_str(&tsk) { Ok(sk) => sk, @@ -1085,7 +1088,6 @@ fn shield_funds( }, }; - let extsk = match decode_extended_spending_key(NETWORK.hrp_sapling_extended_spending_key(), &extsk) { @@ -1112,30 +1114,18 @@ fn shield_funds( } }; + // get latest height and anchor + + let latest_scanned_height = anchor_and_height.1; + let latest_anchor = anchor_and_height.0; + // get UTXOs from DB - let utxos = match get_utxos(&NETWORK, &db_cache) { + let utxos = match get_confirmed_utxos_for_address(&NETWORK, &db_cache, latest_anchor, &t_addr_str) { Ok(u) => u, Err(e) => { return Err(format_err!("Error getting UTXOs {}",e)); }, }; - - // verify that the addresses of the UTXOs are correspond to the given t-address - let distinct_addresses: Vec<&UnspentTransactionOutput> = utxos.iter().filter(|utxo| utxo.address != t_addr).collect::>(); - - if distinct_addresses.len() > 0 { - return Err(format_err!("one or more UTXOs correspond to other addresses that don't match the provided SecretKey")); - } - - // check that the utxos are confirmed - - let latest_scanned_height = anchor_and_height.1; - let latest_anchor = anchor_and_height.0; - let unconfirmed_funds: Vec = utxos.iter().map(|u| u.height).filter(|h| h > &latest_anchor).collect(); - - if unconfirmed_funds.len() > 0 { - return Err(format_err!("one or more UTXOs are unconfirmed ")); - } let total_amount = match Amount::from_i64(utxos.iter().map(|u| i64::from(u.value)).sum::()) { Ok(a) => a, @@ -1229,7 +1219,6 @@ fn shield_funds( } up.insert_sent_note( - &NETWORK, tx_ref, output_index as usize, AccountId(account), @@ -1260,7 +1249,7 @@ pub extern "C" fn zcashlc_shield_funds( ) -> i64 { let res = catch_panic(|| { - let db_data = wallet_db(db_data, db_data_len)?; + let db_data = wallet_db::(db_data, db_data_len)?; let db_cache = block_db(db_cache, db_cache_len)?; let account = if account >= 0 { account as u32