diff --git a/zcash_client_backend/CHANGELOG.md b/zcash_client_backend/CHANGELOG.md index 6ffac75bb..f04a1640f 100644 --- a/zcash_client_backend/CHANGELOG.md +++ b/zcash_client_backend/CHANGELOG.md @@ -88,7 +88,8 @@ and this library adheres to Rust's notion of - The `zcash_client_backend::data_api::SentTransaction` type has been substantially modified to accommodate handling of transparent inputs. Per-output data has been split out into a new struct `SentTransactionOutput` - and `SentTransaction` can now contain multiple outputs. + and `SentTransaction` can now contain multiple outputs, and tracks the + fee paid. - `data_api::WalletWrite::store_received_tx` has been renamed to `store_decrypted_tx`. - `data_api::ReceivedTransaction` has been renamed to `DecryptedTransaction`, diff --git a/zcash_client_backend/src/data_api.rs b/zcash_client_backend/src/data_api.rs index aafbebbcc..59ddb37fa 100644 --- a/zcash_client_backend/src/data_api.rs +++ b/zcash_client_backend/src/data_api.rs @@ -245,6 +245,7 @@ pub struct SentTransaction<'a> { pub created: time::OffsetDateTime, pub account: AccountId, pub outputs: Vec>, + pub fee_amount: Amount, #[cfg(feature = "transparent-inputs")] pub utxos_spent: Vec, } diff --git a/zcash_client_backend/src/data_api/wallet.rs b/zcash_client_backend/src/data_api/wallet.rs index e833a57ed..42128d3c3 100644 --- a/zcash_client_backend/src/data_api/wallet.rs +++ b/zcash_client_backend/src/data_api/wallet.rs @@ -316,7 +316,7 @@ where } // Create the transaction - let mut builder = Builder::new(params.clone(), height); + let mut builder = Builder::new_with_fee(params.clone(), height, DEFAULT_FEE); for selected in spendable_notes { let from = extfvk .fvk @@ -400,8 +400,9 @@ where wallet_db.store_sent_tx(&SentTransaction { tx: &tx, created: time::OffsetDateTime::now_utc(), - outputs: sent_outputs, account, + outputs: sent_outputs, + fee_amount: DEFAULT_FEE, #[cfg(feature = "transparent-inputs")] utxos_spent: vec![], }) @@ -494,7 +495,7 @@ where let amount_to_shield = (total_amount - fee).ok_or_else(|| E::from(Error::InvalidAmount))?; - let mut builder = Builder::new(params.clone(), latest_scanned_height); + let mut builder = Builder::new_with_fee(params.clone(), latest_scanned_height, fee); let secret_key = sk.derive_external_secret_key(child_index).unwrap(); for utxo in &utxos { @@ -531,6 +532,7 @@ where value: amount_to_shield, memo: Some(memo.clone()), }], + fee_amount: fee, utxos_spent: utxos.iter().map(|utxo| utxo.outpoint.clone()).collect(), }) } diff --git a/zcash_client_sqlite/CHANGELOG.md b/zcash_client_sqlite/CHANGELOG.md index 4b153449f..07b2d50fc 100644 --- a/zcash_client_sqlite/CHANGELOG.md +++ b/zcash_client_sqlite/CHANGELOG.md @@ -18,6 +18,14 @@ and this library adheres to Rust's notion of rewinds exceed supported bounds. - An `unstable` feature flag; this is added to parts of the API that may change in any release. It enables `zcash_client_backend`'s `unstable` feature flag. +- New summary views that may be directly accessed in the sqlite database. + The structure of these views should be considered unstable; they may + be replaced by accessors provided by the data access API at some point + in the future: + - `v_transactions` + - `v_tx_received` + - `v_tx_sent` +- `zcash_client_sqlite::wallet::init::WalletMigrationError` ### Changed - Various **BREAKING CHANGES** have been made to the database tables. These will diff --git a/zcash_client_sqlite/Cargo.toml b/zcash_client_sqlite/Cargo.toml index 2281efd76..7d3d6b8ad 100644 --- a/zcash_client_sqlite/Cargo.toml +++ b/zcash_client_sqlite/Cargo.toml @@ -33,12 +33,18 @@ zcash_client_backend = { version = "0.5", path = "../zcash_client_backend" } zcash_primitives = { version = "0.7", path = "../zcash_primitives" } [dev-dependencies] +proptest = "1.0.0" +regex = "1.4" tempfile = "3" zcash_proofs = { version = "0.7", path = "../zcash_proofs" } +zcash_primitives = { version = "0.7", path = "../zcash_primitives", features = ["test-dependencies"] } [features] mainnet = [] -test-dependencies = ["zcash_client_backend/test-dependencies"] +test-dependencies = [ + "zcash_primitives/test-dependencies", + "zcash_client_backend/test-dependencies", +] transparent-inputs = ["hdwallet", "zcash_client_backend/transparent-inputs"] unstable = ["zcash_client_backend/unstable"] diff --git a/zcash_client_sqlite/src/lib.rs b/zcash_client_sqlite/src/lib.rs index fece9fd38..39f231695 100644 --- a/zcash_client_sqlite/src/lib.rs +++ b/zcash_client_sqlite/src/lib.rs @@ -456,7 +456,7 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> { ) -> Result { let nullifiers = self.wallet_db.get_all_nullifiers()?; self.transactionally(|up| { - let tx_ref = wallet::put_tx_data(up, d_tx.tx, None)?; + let tx_ref = wallet::put_tx_data(up, d_tx.tx, None, None)?; let mut spending_account_id: Option = None; for output in d_tx.sapling_outputs { @@ -515,7 +515,12 @@ impl<'a, P: consensus::Parameters> WalletWrite for DataConnStmtCache<'a, P> { fn store_sent_tx(&mut self, sent_tx: &SentTransaction) -> Result { // Update the database atomically, to ensure the result is internally consistent. self.transactionally(|up| { - let tx_ref = wallet::put_tx_data(up, sent_tx.tx, Some(sent_tx.created))?; + let tx_ref = wallet::put_tx_data( + up, + sent_tx.tx, + Some(sent_tx.fee_amount), + Some(sent_tx.created), + )?; // Mark notes as spent. // diff --git a/zcash_client_sqlite/src/prepared.rs b/zcash_client_sqlite/src/prepared.rs index ec8b4b6cf..f80114ae5 100644 --- a/zcash_client_sqlite/src/prepared.rs +++ b/zcash_client_sqlite/src/prepared.rs @@ -84,12 +84,15 @@ impl<'a, P> DataConnStmtCache<'a, P> { SET block = ?, tx_index = ? WHERE txid = ?", )?, stmt_insert_tx_data: wallet_db.conn.prepare( - "INSERT INTO transactions (txid, created, expiry_height, raw) - VALUES (?, ?, ?, ?)", + "INSERT INTO transactions (txid, created, expiry_height, raw, fee) + VALUES (?, ?, ?, ?, ?)", )?, stmt_update_tx_data: wallet_db.conn.prepare( "UPDATE transactions - SET expiry_height = ?, raw = ? WHERE txid = ?", + SET expiry_height = :expiry_height, + raw = :raw, + fee = IFNULL(:fee, fee) + WHERE txid = :txid", )?, stmt_select_tx_ref: wallet_db.conn.prepare( "SELECT id_tx FROM transactions WHERE txid = ?", @@ -132,12 +135,17 @@ impl<'a, P> DataConnStmtCache<'a, P> { )?, stmt_update_sent_note: wallet_db.conn.prepare( "UPDATE sent_notes - SET from_account = ?, address = ?, value = ?, memo = ? - WHERE tx = ? AND output_pool = ? AND output_index = ?", + SET from_account = :account, + address = :address, + value = :value, + memo = IFNULL(:memo, memo) + WHERE tx = :tx + AND output_pool = :output_pool + AND output_index = :output_index", )?, stmt_insert_sent_note: wallet_db.conn.prepare( "INSERT INTO sent_notes (tx, output_pool, output_index, from_account, address, value, memo) - VALUES (?, ?, ?, ?, ?, ?, ?)", + VALUES (:tx, :output_pool, :output_index, :from_account, :address, :value, :memo)" )?, stmt_insert_witness: wallet_db.conn.prepare( "INSERT INTO sapling_witnesses (note, block, witness) @@ -226,12 +234,14 @@ impl<'a, P> DataConnStmtCache<'a, P> { created_at: Option, expiry_height: BlockHeight, raw_tx: &[u8], + fee: Option, ) -> Result { self.stmt_insert_tx_data.execute(params![ &txid.as_ref()[..], created_at, u32::from(expiry_height), - raw_tx + raw_tx, + fee.map(i64::from) ])?; Ok(self.wallet_db.conn.last_insert_rowid()) @@ -244,13 +254,16 @@ impl<'a, P> DataConnStmtCache<'a, P> { &mut self, expiry_height: BlockHeight, raw_tx: &[u8], + fee: Option, txid: &TxId, ) -> Result { - match self.stmt_update_tx_data.execute(params![ - u32::from(expiry_height), - raw_tx, - &txid.as_ref()[..], - ])? { + let sql_args: &[(&str, &dyn ToSql)] = &[ + (":expiry_height", &u32::from(expiry_height)), + (":raw", &raw_tx), + (":fee", &fee.map(i64::from)), + (":txid", &&txid.as_ref()[..]), + ]; + match self.stmt_update_tx_data.execute_named(sql_args)? { 0 => Ok(false), 1 => Ok(true), _ => unreachable!("txid column is marked as UNIQUE"), @@ -388,7 +401,12 @@ impl<'a, P> DataConnStmtCache<'a, P> { (":value", &(value as i64)), (":rcm", &rcm.as_ref()), (":nf", &nf.as_ref().map(|nf| nf.0.as_ref())), - (":memo", &memo.map(|m| m.as_slice())), + ( + ":memo", + &memo + .filter(|m| *m != &MemoBytes::empty()) + .map(|m| m.as_slice()), + ), (":is_change", &is_change), ]; @@ -425,7 +443,12 @@ impl<'a, P> DataConnStmtCache<'a, P> { (":value", &(value as i64)), (":rcm", &rcm.as_ref()), (":nf", &nf.as_ref().map(|nf| nf.0.as_ref())), - (":memo", &memo.map(|m| m.as_slice())), + ( + ":memo", + &memo + .filter(|m| *m != &MemoBytes::empty()) + .map(|m| m.as_slice()), + ), (":is_change", &is_change), (":tx", &tx_ref), (":output_index", &(output_index as i64)), @@ -470,17 +493,21 @@ impl<'a, P> DataConnStmtCache<'a, P> { value: Amount, memo: Option<&MemoBytes>, ) -> Result<(), SqliteClientError> { - let ivalue: i64 = value.into(); - self.stmt_insert_sent_note.execute(params![ - tx_ref, - pool_type.typecode(), - (output_index as i64), - u32::from(account), - to_str, - ivalue, - memo.map(|m| m.as_slice()), - ])?; - + let sql_args: &[(&str, &dyn ToSql)] = &[ + (":tx", &tx_ref), + (":output_pool", &pool_type.typecode()), + (":output_index", &i64::try_from(output_index).unwrap()), + (":from_account", &u32::from(account)), + (":address", &to_str), + (":value", &i64::from(value)), + ( + ":memo", + &memo + .filter(|m| *m != &MemoBytes::empty()) + .map(|m| m.as_slice()), + ), + ]; + self.stmt_insert_sent_note.execute_named(sql_args)?; Ok(()) } @@ -498,16 +525,21 @@ impl<'a, P> DataConnStmtCache<'a, P> { pool_type: PoolType, output_index: usize, ) -> Result { - let ivalue: i64 = value.into(); - match self.stmt_update_sent_note.execute(params![ - u32::from(account), - to_str, - ivalue, - &memo.map(|m| m.as_slice()), - tx_ref, - pool_type.typecode(), - output_index as i64, - ])? { + let sql_args: &[(&str, &dyn ToSql)] = &[ + (":account", &u32::from(account)), + (":address", &to_str), + (":value", &i64::from(value)), + ( + ":memo", + &memo + .filter(|m| *m != &MemoBytes::empty()) + .map(|m| m.as_slice()), + ), + (":tx", &tx_ref), + (":output_pool", &pool_type.typecode()), + (":output_index", &i64::try_from(output_index).unwrap()), + ]; + match self.stmt_update_sent_note.execute_named(sql_args)? { 0 => Ok(false), 1 => Ok(true), _ => unreachable!("tx_output constraint is marked as UNIQUE"), diff --git a/zcash_client_sqlite/src/wallet.rs b/zcash_client_sqlite/src/wallet.rs index 709251dbc..eb009250a 100644 --- a/zcash_client_sqlite/src/wallet.rs +++ b/zcash_client_sqlite/src/wallet.rs @@ -936,6 +936,7 @@ pub fn put_tx_meta<'a, P, N>( pub fn put_tx_data<'a, P>( stmts: &mut DataConnStmtCache<'a, P>, tx: &Transaction, + fee: Option, created_at: Option, ) -> Result { let txid = tx.txid(); @@ -943,9 +944,9 @@ pub fn put_tx_data<'a, P>( let mut raw_tx = vec![]; tx.write(&mut raw_tx)?; - if !stmts.stmt_update_tx_data(tx.expiry_height(), &raw_tx, &txid)? { + if !stmts.stmt_update_tx_data(tx.expiry_height(), &raw_tx, fee, &txid)? { // It isn't there, so insert our transaction into the database. - stmts.stmt_insert_tx_data(&txid, created_at, tx.expiry_height(), &raw_tx) + stmts.stmt_insert_tx_data(&txid, created_at, tx.expiry_height(), &raw_tx, fee) } else { // It was there, so grab its row number. stmts.stmt_select_tx_ref(&txid) diff --git a/zcash_client_sqlite/src/wallet/init.rs b/zcash_client_sqlite/src/wallet/init.rs index f0de43e69..9613d12fc 100644 --- a/zcash_client_sqlite/src/wallet/init.rs +++ b/zcash_client_sqlite/src/wallet/init.rs @@ -1,5 +1,5 @@ //! Functions for initializing the various databases. -use rusqlite::{self, params, types::ToSql, Connection, Transaction, NO_PARAMS}; +use rusqlite::{self, params, types::ToSql, Connection, NO_PARAMS}; use schemer::{migration, Migration, Migrator, MigratorError}; use schemer_rusqlite::{RusqliteAdapter, RusqliteMigration}; use secrecy::{ExposeSecret, SecretVec}; @@ -10,6 +10,7 @@ use uuid::Uuid; use zcash_primitives::{ block::BlockHash, consensus::{self, BlockHeight}, + transaction::components::amount::BalanceError, zip32::AccountId, }; @@ -38,6 +39,9 @@ pub enum WalletMigrationError { /// Wrapper for rusqlite errors. DbError(rusqlite::Error), + + /// Wrapper for amount balance violations + BalanceError(BalanceError), } impl From for WalletMigrationError { @@ -46,6 +50,12 @@ impl From for WalletMigrationError { } } +impl From for WalletMigrationError { + fn from(e: BalanceError) -> Self { + WalletMigrationError::BalanceError(e) + } +} + impl fmt::Display for WalletMigrationError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match &self { @@ -59,6 +69,7 @@ impl fmt::Display for WalletMigrationError { write!(f, "Wallet database is corrupted: {}", reason) } WalletMigrationError::DbError(e) => write!(f, "{}", e), + WalletMigrationError::BalanceError(e) => write!(f, "Balance error: {:?}", e), } } } @@ -84,7 +95,7 @@ migration!( impl RusqliteMigration for WalletMigration0 { type Error = WalletMigrationError; - fn up(&self, transaction: &Transaction) -> Result<(), WalletMigrationError> { + fn up(&self, transaction: &rusqlite::Transaction) -> Result<(), WalletMigrationError> { transaction.execute_batch( // We set the user_version field of the database to a constant value of 8 to allow // correct integration with the Android SDK with versions of the database that were @@ -155,7 +166,7 @@ impl RusqliteMigration for WalletMigration0 { Ok(()) } - fn down(&self, _transaction: &Transaction) -> Result<(), WalletMigrationError> { + fn down(&self, _transaction: &rusqlite::Transaction) -> Result<(), WalletMigrationError> { // We should never down-migrate the first migration, as that can irreversibly // destroy data. panic!("Cannot revert the initial migration."); @@ -174,7 +185,7 @@ migration!( impl RusqliteMigration for WalletMigration1 { type Error = WalletMigrationError; - fn up(&self, transaction: &Transaction) -> Result<(), WalletMigrationError> { + fn up(&self, transaction: &rusqlite::Transaction) -> Result<(), WalletMigrationError> { transaction.execute_batch( "CREATE TABLE IF NOT EXISTS utxos ( id_utxo INTEGER PRIMARY KEY, @@ -192,20 +203,26 @@ impl RusqliteMigration for WalletMigration1 { Ok(()) } - fn down(&self, transaction: &Transaction) -> Result<(), WalletMigrationError> { + fn down(&self, transaction: &rusqlite::Transaction) -> Result<(), WalletMigrationError> { transaction.execute_batch("DROP TABLE utxos;")?; Ok(()) } } -struct WalletMigration2 { +struct WalletMigration2

{ params: P, seed: Option>, } -impl Migration for WalletMigration2

{ +impl

WalletMigration2

{ + fn id() -> Uuid { + Uuid::parse_str("be57ef3b-388e-42ea-97e2-678dafcf9754").unwrap() + } +} + +impl

Migration for WalletMigration2

{ fn id(&self) -> Uuid { - ::uuid::Uuid::parse_str("be57ef3b-388e-42ea-97e2-678dafcf9754").unwrap() + WalletMigration2::

::id() } fn dependencies(&self) -> HashSet { @@ -223,7 +240,7 @@ impl Migration for WalletMigration2

{ impl RusqliteMigration for WalletMigration2

{ type Error = WalletMigrationError; - fn up(&self, transaction: &Transaction) -> Result<(), WalletMigrationError> { + fn up(&self, transaction: &rusqlite::Transaction) -> Result<(), WalletMigrationError> { // // Update the accounts table to store ufvks rather than extfvks // @@ -377,7 +394,7 @@ impl RusqliteMigration for WalletMigration2

{ )?; let mut stmt_insert_sent_note = transaction.prepare( - "INSERT INTO sent_notes_new + "INSERT INTO sent_notes_new (id_note, tx, output_pool, output_index, from_account, address, value, memo) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", )?; @@ -429,7 +446,7 @@ impl RusqliteMigration for WalletMigration2

{ Ok(()) } - fn down(&self, _transaction: &Transaction) -> Result<(), WalletMigrationError> { + fn down(&self, _transaction: &rusqlite::Transaction) -> Result<(), WalletMigrationError> { // TODO: something better than just panic? panic!("Cannot revert this migration."); } @@ -473,6 +490,14 @@ impl RusqliteMigration for WalletMigration2

{ pub fn init_wallet_db( wdb: &mut WalletDb

, seed: Option>, +) -> Result<(), MigratorError> { + init_wallet_db_internal(wdb, seed, None) +} + +fn init_wallet_db_internal( + wdb: &mut WalletDb

, + seed: Option>, + target_migration: Option, ) -> Result<(), MigratorError> { wdb.conn .execute("PRAGMA foreign_keys = OFF", NO_PARAMS) @@ -481,20 +506,10 @@ pub fn init_wallet_db( adapter.init().expect("Migrations table setup succeeds."); let mut migrator = Migrator::new(adapter); - let migration0 = Box::new(WalletMigration0 {}); - let migration1 = Box::new(WalletMigration1 {}); - let migration2 = Box::new(WalletMigration2 { - params: wdb.params.clone(), - seed, - }); - let addrs_migration = Box::new(migrations::AddressesTableMigration { - params: wdb.params.clone(), - }); - migrator - .register_multiple(vec![migration0, migration1, migration2, addrs_migration]) + .register_multiple(migrations::all_migrations(&wdb.params, seed)) .expect("Wallet migration registration should have been successful."); - migrator.up(None)?; + migrator.up(target_migration)?; wdb.conn .execute("PRAGMA foreign_keys = ON", NO_PARAMS) .map_err(|e| MigratorError::Adapter(WalletMigrationError::from(e)))?; @@ -682,8 +697,9 @@ mod tests { use zcash_primitives::{ block::BlockHash, - consensus::{BlockHeight, Parameters}, + consensus::{BlockHeight, BranchId, Parameters}, sapling::keys::DiversifiableFullViewingKey, + transaction::{TransactionData, TxVersion}, zip32::ExtendedFullViewingKey, }; @@ -704,12 +720,10 @@ mod tests { let mut db_data = WalletDb::for_path(data_file.path(), tests::network()).unwrap(); init_wallet_db(&mut db_data, None).unwrap(); - let mut stmt_schema_sql = db_data - .conn - .prepare("SELECT sql FROM sqlite_schema WHERE type = 'table' ORDER BY tbl_name") - .unwrap(); - let mut rows = stmt_schema_sql.query(NO_PARAMS).unwrap(); - let expected = vec![ + use regex::Regex; + let re = Regex::new(r"\s+").unwrap(); + + let expected_tables = vec![ "CREATE TABLE \"accounts\" ( account INTEGER PRIMARY KEY, ufvk TEXT NOT NULL @@ -754,8 +768,8 @@ mod tests { CONSTRAINT witness_height UNIQUE (note, block) )", "CREATE TABLE schemer_migrations ( - id blob PRIMARY KEY - )", + id blob PRIMARY KEY + )", "CREATE TABLE \"sent_notes\" ( id_note INTEGER PRIMARY KEY, tx INTEGER NOT NULL, @@ -777,6 +791,7 @@ mod tests { tx_index INTEGER, expiry_height INTEGER, raw BLOB, + fee INTEGER, FOREIGN KEY (block) REFERENCES blocks(height) )", "CREATE TABLE utxos ( @@ -793,10 +808,125 @@ mod tests { )", ]; + let mut tables_query = db_data + .conn + .prepare("SELECT sql FROM sqlite_schema WHERE type = 'table' ORDER BY tbl_name") + .unwrap(); + let mut rows = tables_query.query(NO_PARAMS).unwrap(); let mut expected_idx = 0; while let Some(row) = rows.next().unwrap() { let sql: String = row.get(0).unwrap(); - assert_eq!(&sql, expected[expected_idx]); + assert_eq!( + re.replace_all(&sql, " "), + re.replace_all(expected_tables[expected_idx], " ") + ); + expected_idx += 1; + } + + let expected_views = vec![ + "CREATE VIEW v_transactions AS + SELECT id_tx, + mined_height, + tx_index, + txid, + expiry_height, + raw, + SUM(value) + MAX(fee) AS net_value, + SUM(is_change) > 0 AS has_change, + SUM(memo_present) AS memo_count + FROM ( + SELECT transactions.id_tx AS id_tx, + transactions.block AS mined_height, + transactions.tx_index AS tx_index, + transactions.txid AS txid, + transactions.expiry_height AS expiry_height, + transactions.raw AS raw, + 0 AS fee, + CASE + WHEN received_notes.is_change THEN 0 + ELSE value + END AS value, + received_notes.is_change AS is_change, + CASE + WHEN received_notes.memo IS NULL THEN 0 + ELSE 1 + END AS memo_present + FROM transactions + JOIN received_notes ON transactions.id_tx = received_notes.tx + UNION + SELECT transactions.id_tx AS id_tx, + transactions.block AS mined_height, + transactions.tx_index AS tx_index, + transactions.txid AS txid, + transactions.expiry_height AS expiry_height, + transactions.raw AS raw, + transactions.fee AS fee, + -sent_notes.value AS value, + false AS is_change, + CASE + WHEN sent_notes.memo IS NULL THEN 0 + ELSE 1 + END AS memo_present + FROM transactions + JOIN sent_notes ON transactions.id_tx = sent_notes.tx + ) + GROUP BY id_tx", + "CREATE VIEW v_tx_received AS + SELECT transactions.id_tx AS id_tx, + transactions.block AS mined_height, + transactions.tx_index AS tx_index, + transactions.txid AS txid, + SUM(received_notes.value) AS received_total, + COUNT(received_notes.id_note) AS received_note_count, + SUM( + CASE + WHEN received_notes.memo IS NULL THEN 0 + ELSE 1 + END + ) AS memo_count, + blocks.time AS block_time + FROM transactions + JOIN received_notes + ON transactions.id_tx = received_notes.tx + LEFT JOIN blocks + ON transactions.block = blocks.height + GROUP BY received_notes.tx", + "CREATE VIEW v_tx_sent AS + SELECT transactions.id_tx AS id_tx, + transactions.block AS mined_height, + transactions.tx_index AS tx_index, + transactions.txid AS txid, + transactions.expiry_height AS expiry_height, + transactions.raw AS raw, + SUM(sent_notes.value) AS sent_total, + COUNT(sent_notes.id_note) AS sent_note_count, + SUM( + CASE + WHEN sent_notes.memo IS NULL THEN 0 + ELSE 1 + END + ) AS memo_count, + blocks.time AS block_time + FROM transactions + JOIN sent_notes + ON transactions.id_tx = sent_notes.tx + LEFT JOIN blocks + ON transactions.block = blocks.height + GROUP BY sent_notes.tx", + ]; + + let mut views_query = db_data + .conn + .prepare("SELECT sql FROM sqlite_schema WHERE type = 'view' ORDER BY tbl_name") + .unwrap(); + let mut rows = views_query.query(NO_PARAMS).unwrap(); + let mut expected_idx = 0; + while let Some(row) = rows.next().unwrap() { + let sql: String = row.get(0).unwrap(); + assert_eq!( + re.replace_all(&sql, " "), + re.replace_all(expected_views[expected_idx], " ") + ); expected_idx += 1; } } @@ -1041,9 +1171,25 @@ mod tests { "INSERT INTO blocks (height, hash, time, sapling_tree) VALUES (0, 0, 0, '')", NO_PARAMS, )?; + + let tx = TransactionData::from_parts( + TxVersion::Sapling, + BranchId::Canopy, + 0, + BlockHeight::from(0), + None, + None, + None, + None, + ) + .freeze() + .unwrap(); + + let mut tx_bytes = vec![]; + tx.write(&mut tx_bytes).unwrap(); wdb.conn.execute( - "INSERT INTO transactions (block, id_tx, txid) VALUES (0, 0, '')", - NO_PARAMS, + "INSERT INTO transactions (block, id_tx, txid, raw) VALUES (0, 0, '', ?)", + &[&tx_bytes[..]], )?; wdb.conn.execute( "INSERT INTO sent_notes (tx, output_index, from_account, address, value) diff --git a/zcash_client_sqlite/src/wallet/init/migrations.rs b/zcash_client_sqlite/src/wallet/init/migrations.rs index deea2236c..ab56dc203 100644 --- a/zcash_client_sqlite/src/wallet/init/migrations.rs +++ b/zcash_client_sqlite/src/wallet/init/migrations.rs @@ -1,2 +1,30 @@ mod addresses_table; pub(super) use addresses_table::AddressesTableMigration; + +mod add_transaction_views; + +use schemer_rusqlite::RusqliteMigration; +use secrecy::SecretVec; +use zcash_primitives::consensus; + +use super::{WalletMigration0, WalletMigration1, WalletMigration2, WalletMigrationError}; + +pub(super) fn all_migrations( + params: &P, + seed: Option>, +) -> Vec>> { + vec![ + Box::new(WalletMigration0 {}), + Box::new(WalletMigration1 {}), + Box::new(WalletMigration2 { + params: params.clone(), + seed, + }), + Box::new(AddressesTableMigration { + params: params.clone(), + }), + Box::new(add_transaction_views::Migration { + params: params.clone(), + }), + ] +} diff --git a/zcash_client_sqlite/src/wallet/init/migrations/add_transaction_views.rs b/zcash_client_sqlite/src/wallet/init/migrations/add_transaction_views.rs new file mode 100644 index 000000000..3cefb8645 --- /dev/null +++ b/zcash_client_sqlite/src/wallet/init/migrations/add_transaction_views.rs @@ -0,0 +1,411 @@ +//! Functions for initializing the various databases. +use rusqlite::{self, types::ToSql, OptionalExtension, NO_PARAMS}; +use schemer::{self}; +use schemer_rusqlite::RusqliteMigration; + +use std::collections::HashSet; + +use uuid::Uuid; + +use zcash_primitives::{ + consensus::{self, BlockHeight, BranchId}, + transaction::{components::amount::Amount, Transaction}, +}; + +use super::super::{WalletMigration2, WalletMigrationError}; + +pub(crate) fn migration_id() -> Uuid { + Uuid::parse_str("282fad2e-8372-4ca0-8bed-71821320909f").unwrap() +} + +pub(crate) struct Migration

{ + pub(super) params: P, +} + +impl

schemer::Migration for Migration

{ + fn id(&self) -> Uuid { + migration_id() + } + + fn dependencies(&self) -> HashSet { + let mut deps = HashSet::new(); + deps.insert(WalletMigration2::

::id()); + deps + } + + fn description(&self) -> &'static str { + "Add transaction summary views & add fee information to transactions." + } +} + +impl RusqliteMigration for Migration

{ + type Error = WalletMigrationError; + + fn up(&self, transaction: &rusqlite::Transaction) -> Result<(), WalletMigrationError> { + transaction.execute_batch("ALTER TABLE transactions ADD COLUMN fee INTEGER;")?; + + let mut stmt_list_txs = + transaction.prepare("SELECT id_tx, raw, block FROM transactions")?; + + let mut stmt_set_fee = + transaction.prepare("UPDATE transactions SET fee = ? WHERE id_tx = ?")?; + + let mut stmt_find_utxo_value = transaction + .prepare("SELECT value_zat FROM utxos WHERE prevout_txid = ? AND prevout_idx = ?")?; + + let mut tx_rows = stmt_list_txs.query(NO_PARAMS)?; + while let Some(row) = tx_rows.next()? { + let id_tx: i64 = row.get(0)?; + let tx_bytes: Option> = row.get(1)?; + let h: u32 = row.get(2)?; + let block_height = BlockHeight::from(h); + + // If only transaction metadata has been stored, and not transaction data, the fee + // information will eventually be set when the full transaction data is inserted. + if let Some(b) = tx_bytes { + let tx = + Transaction::read(&b[..], BranchId::for_height(&self.params, block_height)) + .map_err(|e| { + WalletMigrationError::CorruptedData(format!( + "Parsing failed for transaction {:?}: {:?}", + id_tx, e + )) + })?; + + let fee_paid = tx.fee_paid(|op| { + let op_amount = stmt_find_utxo_value + .query_row(&[op.hash().to_sql()?, op.n().to_sql()?], |row| { + row.get::<_, i64>(0) + }) + .optional() + .map_err(WalletMigrationError::DbError)?; + + op_amount.map_or_else( + || { + Err(WalletMigrationError::CorruptedData(format!( + "Unable to find UTXO corresponding to outpoint {:?}", + op + ))) + }, + |i| { + Amount::from_i64(i).map_err(|_| { + WalletMigrationError::CorruptedData(format!( + "UTXO amount out of range in outpoint {:?}", + op + )) + }) + }, + ) + })?; + + stmt_set_fee.execute(&[i64::from(fee_paid), id_tx])?; + } + } + + transaction.execute_batch( + "UPDATE sent_notes SET memo = NULL + WHERE memo = X'F600000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000'; + UPDATE received_notes SET memo = NULL + WHERE memo = X'F600000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000';")?; + + transaction.execute_batch( + "CREATE VIEW v_tx_sent AS + SELECT transactions.id_tx AS id_tx, + transactions.block AS mined_height, + transactions.tx_index AS tx_index, + transactions.txid AS txid, + transactions.expiry_height AS expiry_height, + transactions.raw AS raw, + SUM(sent_notes.value) AS sent_total, + COUNT(sent_notes.id_note) AS sent_note_count, + SUM( + CASE + WHEN sent_notes.memo IS NULL THEN 0 + ELSE 1 + END + ) AS memo_count, + blocks.time AS block_time + FROM transactions + JOIN sent_notes + ON transactions.id_tx = sent_notes.tx + LEFT JOIN blocks + ON transactions.block = blocks.height + GROUP BY sent_notes.tx; + CREATE VIEW v_tx_received AS + SELECT transactions.id_tx AS id_tx, + transactions.block AS mined_height, + transactions.tx_index AS tx_index, + transactions.txid AS txid, + SUM(received_notes.value) AS received_total, + COUNT(received_notes.id_note) AS received_note_count, + SUM( + CASE + WHEN received_notes.memo IS NULL THEN 0 + ELSE 1 + END + ) AS memo_count, + blocks.time AS block_time + FROM transactions + JOIN received_notes + ON transactions.id_tx = received_notes.tx + LEFT JOIN blocks + ON transactions.block = blocks.height + GROUP BY received_notes.tx; + CREATE VIEW v_transactions AS + SELECT id_tx, + mined_height, + tx_index, + txid, + expiry_height, + raw, + SUM(value) + MAX(fee) AS net_value, + SUM(is_change) > 0 AS has_change, + SUM(memo_present) AS memo_count + FROM ( + SELECT transactions.id_tx AS id_tx, + transactions.block AS mined_height, + transactions.tx_index AS tx_index, + transactions.txid AS txid, + transactions.expiry_height AS expiry_height, + transactions.raw AS raw, + 0 AS fee, + CASE + WHEN received_notes.is_change THEN 0 + ELSE value + END AS value, + received_notes.is_change AS is_change, + CASE + WHEN received_notes.memo IS NULL THEN 0 + ELSE 1 + END AS memo_present + FROM transactions + JOIN received_notes ON transactions.id_tx = received_notes.tx + UNION + SELECT transactions.id_tx AS id_tx, + transactions.block AS mined_height, + transactions.tx_index AS tx_index, + transactions.txid AS txid, + transactions.expiry_height AS expiry_height, + transactions.raw AS raw, + transactions.fee AS fee, + -sent_notes.value AS value, + false AS is_change, + CASE + WHEN sent_notes.memo IS NULL THEN 0 + ELSE 1 + END AS memo_present + FROM transactions + JOIN sent_notes ON transactions.id_tx = sent_notes.tx + ) + GROUP BY id_tx;", + )?; + + Ok(()) + } + + fn down(&self, _transaction: &rusqlite::Transaction) -> Result<(), WalletMigrationError> { + // TODO: something better than just panic? + panic!("Cannot revert this migration."); + } +} + +#[cfg(test)] +mod tests { + use rusqlite::{self, NO_PARAMS}; + use tempfile::NamedTempFile; + + #[cfg(feature = "transparent-inputs")] + use { + crate::wallet::init::WalletMigration2, + rusqlite::params, + zcash_client_backend::{encoding::AddressCodec, keys::UnifiedSpendingKey}, + zcash_primitives::{ + consensus::{BlockHeight, BranchId, Network}, + legacy::{keys::IncomingViewingKey, Script}, + transaction::{ + components::{ + transparent::{self, Authorized, OutPoint}, + Amount, TxIn, TxOut, + }, + TransactionData, TxVersion, + }, + zip32::AccountId, + }, + }; + + use crate::{ + tests, + wallet::init::{ + init_wallet_db, init_wallet_db_internal, + migrations::addresses_table::ADDRESSES_TABLE_MIGRATION, + }, + WalletDb, + }; + + #[test] + fn transaction_views() { + let data_file = NamedTempFile::new().unwrap(); + let mut db_data = WalletDb::for_path(data_file.path(), tests::network()).unwrap(); + init_wallet_db_internal(&mut db_data, None, Some(ADDRESSES_TABLE_MIGRATION)).unwrap(); + + db_data.conn.execute_batch( + "INSERT INTO accounts (account, ufvk) VALUES (0, ''); + INSERT INTO blocks (height, hash, time, sapling_tree) VALUES (0, 0, 0, ''); + INSERT INTO transactions (block, id_tx, txid) VALUES (0, 0, ''); + + INSERT INTO sent_notes (tx, output_pool, output_index, from_account, address, value) + VALUES (0, 2, 0, 0, '', 2); + INSERT INTO sent_notes (tx, output_pool, output_index, from_account, address, value, memo) + VALUES (0, 2, 1, 0, '', 3, X'61'); + INSERT INTO sent_notes (tx, output_pool, output_index, from_account, address, value, memo) + VALUES (0, 2, 2, 0, '', 0, X'F600000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000'); + + INSERT INTO received_notes (tx, output_index, account, diversifier, value, rcm, nf, is_change) + VALUES (0, 0, 0, '', 2, '', 'a', false); + INSERT INTO received_notes (tx, output_index, account, diversifier, value, rcm, nf, is_change, memo) + VALUES (0, 3, 0, '', 5, '', 'b', false, X'62'); + INSERT INTO received_notes (tx, output_index, account, diversifier, value, rcm, nf, is_change, memo) + VALUES (0, 4, 0, '', 7, '', 'c', true, X'63');", + ).unwrap(); + + init_wallet_db(&mut db_data, None).unwrap(); + + let mut q = db_data + .conn + .prepare("SELECT received_total, received_note_count, memo_count FROM v_tx_received") + .unwrap(); + let mut rows = q.query(NO_PARAMS).unwrap(); + let mut row_count = 0; + while let Some(row) = rows.next().unwrap() { + row_count += 1; + let total: i64 = row.get(0).unwrap(); + let count: i64 = row.get(1).unwrap(); + let memo_count: i64 = row.get(2).unwrap(); + assert_eq!(total, 14); + assert_eq!(count, 3); + assert_eq!(memo_count, 2); + } + assert_eq!(row_count, 1); + + let mut q = db_data + .conn + .prepare("SELECT sent_total, sent_note_count, memo_count FROM v_tx_sent") + .unwrap(); + let mut rows = q.query(NO_PARAMS).unwrap(); + let mut row_count = 0; + while let Some(row) = rows.next().unwrap() { + row_count += 1; + let total: i64 = row.get(0).unwrap(); + let count: i64 = row.get(1).unwrap(); + let memo_count: i64 = row.get(2).unwrap(); + assert_eq!(total, 5); + assert_eq!(count, 3); + assert_eq!(memo_count, 1); + } + assert_eq!(row_count, 1); + + let mut q = db_data + .conn + .prepare("SELECT net_value, has_change, memo_count FROM v_transactions") + .unwrap(); + let mut rows = q.query(NO_PARAMS).unwrap(); + let mut row_count = 0; + while let Some(row) = rows.next().unwrap() { + row_count += 1; + let net_value: i64 = row.get(0).unwrap(); + let has_change: bool = row.get(1).unwrap(); + let memo_count: i64 = row.get(2).unwrap(); + assert_eq!(net_value, 2); + assert!(has_change); + assert_eq!(memo_count, 3); + } + assert_eq!(row_count, 1); + } + + #[test] + #[cfg(feature = "transparent-inputs")] + fn migrate_from_wm2() { + let data_file = NamedTempFile::new().unwrap(); + let mut db_data = WalletDb::for_path(data_file.path(), tests::network()).unwrap(); + init_wallet_db_internal(&mut db_data, None, Some(WalletMigration2::::id())) + .unwrap(); + + // create a UTXO to spend + let tx = TransactionData::from_parts( + TxVersion::Sapling, + BranchId::Canopy, + 0, + BlockHeight::from(3), + Some(transparent::Bundle { + vin: vec![TxIn { + prevout: OutPoint::new([1u8; 32], 1), + script_sig: Script(vec![]), + sequence: 0, + }], + vout: vec![TxOut { + value: Amount::from_i64(1100000000).unwrap(), + script_pubkey: Script(vec![]), + }], + authorization: Authorized, + }), + None, + None, + None, + ) + .freeze() + .unwrap(); + + let mut tx_bytes = vec![]; + tx.write(&mut tx_bytes).unwrap(); + + let usk = + UnifiedSpendingKey::from_seed(&tests::network(), &[0u8; 32][..], AccountId::from(0)) + .unwrap(); + let ufvk = usk.to_unified_full_viewing_key(); + let (ua, _) = ufvk.default_address(); + let taddr = ufvk + .transparent() + .and_then(|k| { + k.derive_external_ivk() + .ok() + .map(|k| k.derive_address(0).unwrap()) + }) + .map(|a| a.encode(&tests::network())); + + db_data.conn.execute( + "INSERT INTO accounts (account, ufvk, address, transparent_address) VALUES (0, ?, ?, ?)", + params![ufvk.encode(&tests::network()), ua.encode(&tests::network()), &taddr] + ).unwrap(); + db_data + .conn + .execute_batch( + "INSERT INTO blocks (height, hash, time, sapling_tree) VALUES (0, 0, 0, '');", + ) + .unwrap(); + db_data.conn.execute( + "INSERT INTO utxos (address, prevout_txid, prevout_idx, script, value_zat, height) + VALUES (?, X'0101010101010101010101010101010101010101010101010101010101010101', 1, X'', 1400000000, 1)", + &[taddr] + ).unwrap(); + db_data + .conn + .execute( + "INSERT INTO transactions (block, id_tx, txid, raw) VALUES (0, 0, '', ?)", + params![tx_bytes], + ) + .unwrap(); + + init_wallet_db(&mut db_data, None).unwrap(); + + let fee = db_data + .conn + .query_row( + "SELECT fee FROM transactions WHERE id_tx = 0", + NO_PARAMS, + |row| Ok(Amount::from_i64(row.get(0)?).unwrap()), + ) + .unwrap(); + + assert_eq!(fee, Amount::from_i64(300000000).unwrap()); + } +} diff --git a/zcash_extensions/src/transparent/demo.rs b/zcash_extensions/src/transparent/demo.rs index e9150823b..61240fa39 100644 --- a/zcash_extensions/src/transparent/demo.rs +++ b/zcash_extensions/src/transparent/demo.rs @@ -685,7 +685,7 @@ mod tests { precondition: tze::Precondition::from(0, &Precondition::open(hash_1)), }; - let tx_a = TransactionData::from_parts( + let tx_a = TransactionData::from_parts_zfuture( TxVersion::ZFuture, BranchId::ZFuture, 0, @@ -716,7 +716,7 @@ mod tests { precondition: tze::Precondition::from(0, &Precondition::close(hash_2)), }; - let tx_b = TransactionData::from_parts( + let tx_b = TransactionData::from_parts_zfuture( TxVersion::ZFuture, BranchId::ZFuture, 0, @@ -743,7 +743,7 @@ mod tests { witness: tze::Witness::from(0, &Witness::close(preimage_2)), }; - let tx_c = TransactionData::from_parts( + let tx_c = TransactionData::from_parts_zfuture( TxVersion::ZFuture, BranchId::ZFuture, 0, diff --git a/zcash_primitives/CHANGELOG.md b/zcash_primitives/CHANGELOG.md index 75d1dbdd0..43bf96ae7 100644 --- a/zcash_primitives/CHANGELOG.md +++ b/zcash_primitives/CHANGELOG.md @@ -23,6 +23,16 @@ and this library adheres to Rust's notion of - `DiversifierIndex::{as_bytes}` - `ExtendedSpendingKey::{from_bytes, to_bytes}` - Implementations of `From` and `From` for `DiversifierIndex` +- `zcash_primitives::transaction::Builder` constructors: + - `Builder::new_with_fee` + - `Builder::new_with_rng_and_fee` +- `zcash_primitives::transaction::TransactionData::fee_paid` +- `zcash_primitives::transaction::components::amount::BalanceError` +- Added in `zcash_primitives::transaction::components::sprout` + - `Bundle::value_balance` + - `JSDescription::net_value` +- Added in `zcash_primitives::transaction::components::transparent` + - `Bundle::value_balance` ### Changed - `zcash_primitives::sapling::ViewingKey` now stores `nk` as a diff --git a/zcash_primitives/src/transaction/builder.rs b/zcash_primitives/src/transaction/builder.rs index d2dacc655..21dad4086 100644 --- a/zcash_primitives/src/transaction/builder.rs +++ b/zcash_primitives/src/transaction/builder.rs @@ -141,6 +141,18 @@ impl<'a, P: consensus::Parameters> Builder<'a, P, OsRng> { pub fn new(params: P, target_height: BlockHeight) -> Self { Builder::new_with_rng(params, target_height, OsRng) } + + /// Creates a new `Builder` targeted for inclusion in the block with the given height, using + /// the specified fee, and otherwise default values for general transaction fields and the + /// default OS random. + /// + /// # Default values + /// + /// The expiry height will be set to the given height plus the default transaction + /// expiry delta (20 blocks). + pub fn new_with_fee(params: P, target_height: BlockHeight, fee: Amount) -> Self { + Builder::new_with_rng_and_fee(params, OsRng, target_height, fee) + } } impl<'a, P: consensus::Parameters, R: RngCore + CryptoRng> Builder<'a, P, R> { @@ -154,7 +166,24 @@ impl<'a, P: consensus::Parameters, R: RngCore + CryptoRng> Builder<'a, P, R> { /// /// The fee will be set to the default fee (0.0001 ZEC). pub fn new_with_rng(params: P, target_height: BlockHeight, rng: R) -> Builder<'a, P, R> { - Self::new_internal(params, target_height, rng) + Self::new_internal(params, rng, target_height, DEFAULT_FEE) + } + + /// Creates a new `Builder` targeted for inclusion in the block with the given height, and + /// randomness source, using the specified fee, and otherwise default values for general + /// transaction fields and the default OS random. + /// + /// # Default values + /// + /// The expiry height will be set to the given height plus the default transaction + /// expiry delta (20 blocks). + pub fn new_with_rng_and_fee( + params: P, + rng: R, + target_height: BlockHeight, + fee: Amount, + ) -> Builder<'a, P, R> { + Self::new_internal(params, rng, target_height, fee) } } @@ -163,13 +192,18 @@ impl<'a, P: consensus::Parameters, R: RngCore> Builder<'a, P, R> { /// /// WARNING: THIS MUST REMAIN PRIVATE AS IT ALLOWS CONSTRUCTION /// OF BUILDERS WITH NON-CryptoRng RNGs - fn new_internal(params: P, target_height: BlockHeight, rng: R) -> Builder<'a, P, R> { + fn new_internal( + params: P, + rng: R, + target_height: BlockHeight, + fee: Amount, + ) -> Builder<'a, P, R> { Builder { params: params.clone(), rng, target_height, expiry_height: target_height + DEFAULT_TX_EXPIRY_DELTA, - fee: DEFAULT_FEE, + fee, transparent_builder: TransparentBuilder::empty(), sapling_builder: SaplingBuilder::new(params, target_height), change_address: None, @@ -454,7 +488,7 @@ impl<'a, P: consensus::Parameters, R: RngCore> Builder<'a, P, R> { /// /// WARNING: DO NOT USE IN PRODUCTION pub fn test_only_new_with_rng(params: P, height: BlockHeight, rng: R) -> Builder<'a, P, R> { - Self::new_internal(params, height, rng) + Self::new_internal(params, rng, height, DEFAULT_FEE) } pub fn mock_build(self) -> Result<(Transaction, SaplingMetadata), Error> { diff --git a/zcash_primitives/src/transaction/components/amount.rs b/zcash_primitives/src/transaction/components/amount.rs index 77655f1da..107f14c1d 100644 --- a/zcash_primitives/src/transaction/components/amount.rs +++ b/zcash_primitives/src/transaction/components/amount.rs @@ -222,6 +222,14 @@ impl TryFrom for Amount { } } +/// A type for balance violations in amount addition and subtraction +/// (overflow and underflow of allowed ranges) +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum BalanceError { + Overflow, + Underflow, +} + #[cfg(any(test, feature = "test-dependencies"))] pub mod testing { use proptest::prelude::prop_compose; diff --git a/zcash_primitives/src/transaction/components/sprout.rs b/zcash_primitives/src/transaction/components/sprout.rs index 0c11b0e04..b3e2370f7 100644 --- a/zcash_primitives/src/transaction/components/sprout.rs +++ b/zcash_primitives/src/transaction/components/sprout.rs @@ -17,6 +17,18 @@ pub struct Bundle { pub joinsplit_sig: [u8; 64], } +impl Bundle { + /// The value balance for the bundle. When this is positive, + /// its value is added to the transparent value pool; when it + /// is negative, its value is subtracted from the transparent + /// value pool. + pub fn value_balance(&self) -> Option { + self.joinsplits + .iter() + .try_fold(Amount::zero(), |total, js| total + js.net_value()) + } +} + #[derive(Clone)] #[allow(clippy::upper_case_acronyms)] pub(crate) enum SproutProof { @@ -172,4 +184,12 @@ impl JsDescription { writer.write_all(&self.ciphertexts[0])?; writer.write_all(&self.ciphertexts[1]) } + + /// The net value for the JoinSplit. When this is positive, + /// its value is added to the transparent value pool; when it + /// is negative, its value is subtracted from the transparent + /// value pool. + pub fn net_value(&self) -> Amount { + (self.vpub_new - self.vpub_old).expect("difference is in range [-MAX_MONEY..=MAX_MONEY]") + } } diff --git a/zcash_primitives/src/transaction/components/transparent.rs b/zcash_primitives/src/transaction/components/transparent.rs index f5a132f64..bae708158 100644 --- a/zcash_primitives/src/transaction/components/transparent.rs +++ b/zcash_primitives/src/transaction/components/transparent.rs @@ -7,7 +7,7 @@ use std::io::{self, Read, Write}; use crate::legacy::Script; -use super::amount::Amount; +use super::amount::{Amount, BalanceError}; pub mod builder; @@ -62,6 +62,31 @@ impl Bundle { authorization: f.map_authorization(self.authorization), } } + + /// The amount of value added to or removed from the transparent pool by the action of this + /// bundle. A positive value represents that the containing transaction has funds being + /// transferred out of the transparent pool into shielded pools or to fees; a negative value + /// means that the containing transaction has funds being transferred into the transparent pool + /// from the shielded pools. + pub fn value_balance(&self, mut get_prevout_value: F) -> Result + where + E: From, + F: FnMut(&OutPoint) -> Result, + { + let input_sum = self.vin.iter().try_fold(Amount::zero(), |total, txin| { + get_prevout_value(&txin.prevout) + .and_then(|v| (total + v).ok_or_else(|| BalanceError::Overflow.into())) + })?; + + let output_sum = self + .vout + .iter() + .map(|p| p.value) + .sum::>() + .ok_or(BalanceError::Overflow)?; + + (input_sum - output_sum).ok_or_else(|| BalanceError::Underflow.into()) + } } #[derive(Clone, Debug, PartialEq)] diff --git a/zcash_primitives/src/transaction/mod.rs b/zcash_primitives/src/transaction/mod.rs index 858c39ee4..d82528da3 100644 --- a/zcash_primitives/src/transaction/mod.rs +++ b/zcash_primitives/src/transaction/mod.rs @@ -27,13 +27,14 @@ use crate::{ use self::{ components::{ - amount::Amount, + amount::{Amount, BalanceError}, orchard as orchard_serialization, sapling::{ self, OutputDescription, OutputDescriptionV5, SpendDescription, SpendDescriptionV5, }, sprout::{self, JsDescription}, transparent::{self, TxIn, TxOut}, + OutPoint, }, txid::{to_txid, BlockTxCommitmentDigester, TxIdDigester}, util::sha256d::{HashReader, HashWriter}, @@ -317,7 +318,6 @@ impl TransactionData { sprout_bundle: Option, sapling_bundle: Option>, orchard_bundle: Option>, - #[cfg(feature = "zfuture")] tze_bundle: Option>, ) -> Self { TransactionData { version, @@ -329,6 +329,32 @@ impl TransactionData { sapling_bundle, orchard_bundle, #[cfg(feature = "zfuture")] + tze_bundle: None, + } + } + + #[cfg(feature = "zfuture")] + #[allow(clippy::too_many_arguments)] + pub fn from_parts_zfuture( + version: TxVersion, + consensus_branch_id: BranchId, + lock_time: u32, + expiry_height: BlockHeight, + transparent_bundle: Option>, + sprout_bundle: Option, + sapling_bundle: Option>, + orchard_bundle: Option>, + tze_bundle: Option>, + ) -> Self { + TransactionData { + version, + consensus_branch_id, + lock_time, + expiry_height, + transparent_bundle, + sprout_bundle, + sapling_bundle, + orchard_bundle, tze_bundle, } } @@ -370,6 +396,36 @@ impl TransactionData { self.tze_bundle.as_ref() } + /// Returns the total fees paid by the transaction, given a function that can be used to + /// retrieve the value of previous transactions' transparent outputs that are being spent in + /// this transaction. + pub fn fee_paid(&self, get_prevout: F) -> Result + where + E: From, + F: FnMut(&OutPoint) -> Result, + { + let value_balances = [ + self.transparent_bundle + .as_ref() + .map_or_else(|| Ok(Amount::zero()), |b| b.value_balance(get_prevout))?, + self.sprout_bundle.as_ref().map_or_else( + || Ok(Amount::zero()), + |b| b.value_balance().ok_or(BalanceError::Overflow), + )?, + self.sapling_bundle + .as_ref() + .map_or_else(Amount::zero, |b| b.value_balance), + self.orchard_bundle + .as_ref() + .map_or_else(Amount::zero, |b| *b.value_balance()), + ]; + + value_balances + .iter() + .sum::>() + .ok_or_else(|| BalanceError::Overflow.into()) + } + pub fn digest>(&self, digester: D) -> D::Digest { digester.combine( digester.digest_header( diff --git a/zcash_primitives/src/transaction/tests.rs b/zcash_primitives/src/transaction/tests.rs index 1419479a1..9335d8984 100644 --- a/zcash_primitives/src/transaction/tests.rs +++ b/zcash_primitives/src/transaction/tests.rs @@ -245,21 +245,30 @@ fn zip_0244() { }, }); - ( - TransactionData::from_parts( - txdata.version(), - txdata.consensus_branch_id(), - txdata.lock_time(), - txdata.expiry_height(), - test_bundle, - txdata.sprout_bundle().cloned(), - txdata.sapling_bundle().cloned(), - txdata.orchard_bundle().cloned(), - #[cfg(feature = "zfuture")] - txdata.tze_bundle().cloned(), - ), - txdata.digest(TxIdDigester), - ) + #[cfg(not(feature = "zfuture"))] + let tdata = TransactionData::from_parts( + txdata.version(), + txdata.consensus_branch_id(), + txdata.lock_time(), + txdata.expiry_height(), + test_bundle, + txdata.sprout_bundle().cloned(), + txdata.sapling_bundle().cloned(), + txdata.orchard_bundle().cloned(), + ); + #[cfg(feature = "zfuture")] + let tdata = TransactionData::from_parts_zfuture( + txdata.version(), + txdata.consensus_branch_id(), + txdata.lock_time(), + txdata.expiry_height(), + test_bundle, + txdata.sprout_bundle().cloned(), + txdata.sapling_bundle().cloned(), + txdata.orchard_bundle().cloned(), + txdata.tze_bundle().cloned(), + ); + (tdata, txdata.digest(TxIdDigester)) } for tv in self::data::zip_0244::make_test_vectors() {