diff --git a/src/main/rust/sql.rs b/src/main/rust/sql.rs index 0a57ba7f..cfac0864 100644 --- a/src/main/rust/sql.rs +++ b/src/main/rust/sql.rs @@ -36,6 +36,33 @@ fn address_from_extfvk(extfvk: &ExtendedFullViewingKey) -> String { encode_payment_address(HRP_SAPLING_PAYMENT_ADDRESS_TEST, &addr) } +/// Determine the target height for a transaction, and the height from which to +/// select anchors, based on the current synchronised block chain. +fn get_target_and_anchor_heights(data: &Connection) -> Result<(u32, u32), Error> { + data.query_row_and_then( + "SELECT MIN(height), MAX(height) FROM blocks", + NO_PARAMS, + |row| match (row.get_checked::<_, u32>(0), row.get_checked::<_, u32>(1)) { + // If there are no blocks, the query returns NULL. + (Err(rusqlite::Error::InvalidColumnType(_, _)), _) + | (_, Err(rusqlite::Error::InvalidColumnType(_, _))) => { + Err(format_err!("Must scan blocks first")) + } + (Err(e), _) | (_, Err(e)) => Err(e.into()), + (Ok(min_height), Ok(max_height)) => { + let target_height = max_height + 1; + + // Select an anchor ANCHOR_OFFSET back from the target block, + // unless that would be before the earliest block we have. + let anchor_height = + cmp::max(target_height.saturating_sub(ANCHOR_OFFSET), min_height); + + Ok((target_height, anchor_height)) + } + }, + ) +} + pub fn init_cache_database>(db_cache: P) -> Result<(), Error> { let cache = Connection::open(db_cache)?; cache.execute( @@ -220,6 +247,24 @@ pub fn get_balance>(db_data: P, account: u32) -> Result>(db_data: P, account: u32) -> Result { + let data = Connection::open(db_data)?; + + let (_, anchor_height) = get_target_and_anchor_heights(&data)?; + + let balance = data.query_row( + "SELECT SUM(value) FROM received_notes + INNER JOIN transactions ON transactions.id_tx = received_notes.tx + WHERE account = ? AND spent IS NULL AND transactions.block <= ?", + &[account, anchor_height], + |row| row.get_checked(0).unwrap_or(0), + )?; + + Ok(Amount(balance)) +} + pub fn get_received_memo_as_utf8>( db_data: P, id_note: i64, @@ -571,30 +616,10 @@ pub fn send_to_address>( let ovk = extfvk.fvk.ovk; // Target the next block, assuming we are up-to-date. - let (height, anchor_height) = data.query_row_and_then( - "SELECT MIN(height), MAX(height) FROM blocks", - NO_PARAMS, - |row| match (row.get_checked::<_, u32>(0), row.get_checked::<_, u32>(1)) { - // If there are no blocks, the query returns NULL. - (Err(rusqlite::Error::InvalidColumnType(_, _)), _) - | (_, Err(rusqlite::Error::InvalidColumnType(_, _))) => { - Err(format_err!("Must sync before calling send_to_address()")) - } - (Err(e), _) | (_, Err(e)) => Err(e.into()), - (Ok(min_height), Ok(max_height)) => { - let target_height = max_height + 1; - - // Select an anchor ANCHOR_OFFSET back from the target block, - // unless that would be before the earliest block we have. - let anchor_height = i64::from(cmp::max( - target_height.saturating_sub(ANCHOR_OFFSET), - min_height, - )); - - Ok((target_height, anchor_height)) - } - }, - )?; + let (height, anchor_height) = { + let (target_height, anchor_height) = get_target_and_anchor_heights(&data)?; + (target_height, i64::from(anchor_height)) + }; // The goal of this SQL statement is to select the oldest notes until the required // value has been reached, and then fetch the witnesses at the desired height for the @@ -806,8 +831,8 @@ mod tests { use zip32::{ExtendedFullViewingKey, ExtendedSpendingKey}; use super::{ - get_address, get_balance, init_accounts_table, init_blocks_table, init_cache_database, - init_data_database, scan_cached_blocks, send_to_address, + get_address, get_balance, get_verified_balance, init_accounts_table, init_blocks_table, + init_cache_database, init_data_database, scan_cached_blocks, send_to_address, }; fn test_prover() -> impl TxProver { @@ -1204,7 +1229,7 @@ mod tests { // We cannot do anything if we aren't synchronised match send_to_address(db_data, 1, test_prover(), (0, &extsk), &to, Amount(1), None) { Ok(_) => panic!("Should have failed"), - Err(e) => assert_eq!(e.to_string(), "Must sync before calling send_to_address()"), + Err(e) => assert_eq!(e.to_string(), "Must scan blocks first"), } } @@ -1254,13 +1279,19 @@ mod tests { let (cb, _) = fake_compact_block(1, extfvk.clone(), value); insert_into_cache(db_cache, &cb); scan_cached_blocks(db_cache, db_data).unwrap(); + + // Verified balance matches total balance assert_eq!(get_balance(db_data, 0).unwrap(), value); + assert_eq!(get_verified_balance(db_data, 0).unwrap(), value); // Add more funds to the wallet in a second note let (cb, _) = fake_compact_block(2, extfvk.clone(), value); insert_into_cache(db_cache, &cb); scan_cached_blocks(db_cache, db_data).unwrap(); + + // Verified balance does not include the second note assert_eq!(get_balance(db_data, 0).unwrap().0, 2 * value.0); + assert_eq!(get_verified_balance(db_data, 0).unwrap(), value); // Spend fails because there are insufficient verified notes let extsk2 = ExtendedSpendingKey::master(&[]);