diff --git a/zcash_client_sqlite/src/wallet/init.rs b/zcash_client_sqlite/src/wallet/init.rs index 037b89e9f..a8dae4d6c 100644 --- a/zcash_client_sqlite/src/wallet/init.rs +++ b/zcash_client_sqlite/src/wallet/init.rs @@ -9,11 +9,8 @@ use uuid::Uuid; use zcash_primitives::{ block::BlockHash, - consensus::{self, BlockHeight, BranchId}, - transaction::{ - components::amount::{Amount, BalanceError}, - Transaction, - }, + consensus::{self, BlockHeight}, + transaction::components::amount::BalanceError, zip32::AccountId, }; @@ -217,9 +214,15 @@ struct WalletMigration2

{ seed: Option>, } +impl

WalletMigration2

{ + fn id() -> Uuid { + Uuid::parse_str("be57ef3b-388e-42ea-97e2-678dafcf9754").unwrap() + } +} + impl

Migration for WalletMigration2

{ fn id(&self) -> Uuid { - ::uuid::Uuid::parse_str("be57ef3b-388e-42ea-97e2-678dafcf9754").unwrap() + WalletMigration2::

::id() } fn dependencies(&self) -> HashSet { @@ -449,180 +452,6 @@ impl RusqliteMigration for WalletMigration2

{ } } -struct WalletMigrationAddTxViews

{ - params: P, -} - -impl

Migration for WalletMigrationAddTxViews

{ - fn id(&self) -> Uuid { - ::uuid::Uuid::parse_str("282fad2e-8372-4ca0-8bed-71821320909f").unwrap() - } - - fn dependencies(&self) -> HashSet { - ["be57ef3b-388e-42ea-97e2-678dafcf9754"] - .iter() - .map(|uuidstr| ::uuid::Uuid::parse_str(uuidstr).unwrap()) - .collect() - } - - fn description(&self) -> &'static str { - "Add transaction summary views & add fee information to transactions." - } -} - -impl RusqliteMigration for WalletMigrationAddTxViews

{ - type Error = WalletMigrationError; - - fn up(&self, transaction: &rusqlite::Transaction) -> Result<(), WalletMigrationError> { - transaction.execute_batch("ALTER TABLE transactions ADD COLUMN fee INTEGER;")?; - - let mut stmt_list_txs = - transaction.prepare("SELECT id_tx, raw, block FROM transactions")?; - - let mut stmt_set_fee = - transaction.prepare("UPDATE transactions SET fee = ? WHERE id_tx = ?")?; - - let mut stmt_find_utxo_value = transaction - .prepare("SELECT value_zat FROM utxos WHERE prevout_txid = ? AND prevout_idx = ?")?; - - let mut tx_rows = stmt_list_txs.query(NO_PARAMS)?; - while let Some(row) = tx_rows.next()? { - let id_tx: i64 = row.get(0)?; - let tx_bytes: Vec = row.get(1)?; - let h: u32 = row.get(2)?; - let block_height = BlockHeight::from(h); - - let tx = Transaction::read( - &tx_bytes[..], - BranchId::for_height(&self.params, block_height), - ) - .map_err(|e| { - WalletMigrationError::CorruptedData(format!( - "Parsing failed for transaction {:?}: {:?}", - id_tx, e - )) - })?; - - let fee_paid = tx.fee_paid(|op| { - stmt_find_utxo_value - .query_row(&[op.hash().to_sql()?, op.n().to_sql()?], |row| { - row.get(0).map(|i| Amount::from_i64(i).unwrap()) - }) - .map_err(WalletMigrationError::DbError) - })?; - - stmt_set_fee.execute(&[i64::from(fee_paid), id_tx])?; - } - - transaction.execute_batch( - "CREATE VIEW v_tx_sent AS - SELECT transactions.id_tx AS id_tx, - transactions.block AS mined_height, - transactions.tx_index AS tx_index, - transactions.txid AS txid, - transactions.expiry_height AS expiry_height, - transactions.raw AS raw, - SUM(sent_notes.value) AS sent_total, - COUNT(sent_notes.id_note) AS sent_note_count, - SUM( - CASE - WHEN sent_notes.memo IS NULL THEN 0 - WHEN SUBSTR(sent_notes.memo, 0, 2) = X'F6' THEN 0 - ELSE 1 - END - ) AS memo_count, - blocks.time AS block_time - FROM transactions - JOIN sent_notes - ON transactions.id_tx = sent_notes.tx - LEFT JOIN blocks - ON transactions.block = blocks.height - GROUP BY sent_notes.tx; - CREATE VIEW v_tx_received AS - SELECT transactions.id_tx AS id_tx, - transactions.block AS mined_height, - transactions.tx_index AS tx_index, - transactions.txid AS txid, - SUM(received_notes.value) AS received_total, - COUNT(received_notes.id_note) AS received_note_count, - SUM( - CASE - WHEN received_notes.memo IS NULL THEN 0 - WHEN SUBSTR(received_notes.memo, 0, 2) = X'F6' THEN 0 - ELSE 1 - END - ) AS memo_count, - blocks.time AS block_time - FROM transactions - JOIN received_notes - ON transactions.id_tx = received_notes.tx - LEFT JOIN blocks - ON transactions.block = blocks.height - GROUP BY received_notes.tx; - CREATE VIEW v_transactions AS - SELECT id_tx, - mined_height, - tx_index, - txid, - expiry_height, - raw, - SUM(value) + MAX(fee) AS net_value, - SUM(is_change) > 0 AS has_change, - SUM(memo_present) AS memo_count - FROM ( - SELECT transactions.id_tx AS id_tx, - transactions.block AS mined_height, - transactions.tx_index AS tx_index, - transactions.txid AS txid, - transactions.expiry_height AS expiry_height, - transactions.raw AS raw, - 0 AS fee, - CASE - WHEN received_notes.is_change THEN 0 - ELSE value - END AS value, - received_notes.is_change AS is_change, - CASE - WHEN received_notes.memo IS NULL THEN 0 - WHEN SUBSTR(received_notes.memo, 0, 2) = X'F6' THEN 0 - ELSE 1 - END AS memo_present - FROM transactions - JOIN received_notes ON transactions.id_tx = received_notes.tx - UNION - SELECT transactions.id_tx AS id_tx, - transactions.block AS mined_height, - transactions.tx_index AS tx_index, - transactions.txid AS txid, - transactions.expiry_height AS expiry_height, - transactions.raw AS raw, - transactions.fee AS fee, - -sent_notes.value AS value, - false AS is_change, - CASE - WHEN sent_notes.memo IS NULL THEN 0 - WHEN SUBSTR(sent_notes.memo, 0, 2) = X'F6' THEN 0 - ELSE 1 - END AS memo_present - FROM transactions - JOIN sent_notes ON transactions.id_tx = sent_notes.tx - ) - GROUP BY id_tx;", - )?; - - Ok(()) - } - - fn down(&self, transaction: &rusqlite::Transaction) -> Result<(), WalletMigrationError> { - transaction.execute_batch( - "DROP VIEW v_tx_sent_notes; - DROP VIEW v_tx_received_notes; - DROP VIEW v_tx_notes;", - )?; - Ok(()) - } -} - /// Sets up the internal structure of the data database. /// /// This procedure will automatically perform migration operations to update the wallet database to @@ -661,6 +490,14 @@ impl RusqliteMigration for WalletMigrationAddTxViews

( wdb: &mut WalletDb

, seed: Option>, +) -> Result<(), MigratorError> { + init_wallet_db_internal(wdb, seed, None) +} + +fn init_wallet_db_internal( + wdb: &mut WalletDb

, + seed: Option>, + target_migration: Option, ) -> Result<(), MigratorError> { wdb.conn .execute("PRAGMA foreign_keys = OFF", NO_PARAMS) @@ -669,25 +506,10 @@ pub fn init_wallet_db( adapter.init().expect("Migrations table setup succeeds."); let mut migrator = Migrator::new(adapter); - let migration0 = Box::new(WalletMigration0 {}); - let migration1 = Box::new(WalletMigration1 {}); - let migration2 = Box::new(WalletMigration2 { - params: wdb.params.clone(), - seed, - }); - let migration3 = Box::new(WalletMigrationAddTxViews { - params: wdb.params.clone(), - }); - let migration4 = Box::new(migrations::AddressesTableMigration { - params: wdb.params.clone(), - }); - migrator - .register_multiple(vec![ - migration0, migration1, migration2, migration3, migration4, - ]) + .register_multiple(migrations::all_migrations(&wdb.params, seed)) .expect("Wallet migration registration should have been successful."); - migrator.up(None)?; + migrator.up(target_migration)?; wdb.conn .execute("PRAGMA foreign_keys = ON", NO_PARAMS) .map_err(|e| MigratorError::Adapter(WalletMigrationError::from(e)))?; @@ -1539,82 +1361,6 @@ mod tests { init_wallet_db(&mut db_data, Some(Secret::new(seed.to_vec()))).unwrap(); } - #[test] - fn transaction_views() { - let data_file = NamedTempFile::new().unwrap(); - let mut db_data = WalletDb::for_path(data_file.path(), tests::network()).unwrap(); - init_wallet_db(&mut db_data, None).unwrap(); - - db_data.conn.execute_batch( - "INSERT INTO accounts (account, ufvk) VALUES (0, ''); - INSERT INTO blocks (height, hash, time, sapling_tree) VALUES (0, 0, 0, ''); - INSERT INTO transactions (block, id_tx, txid) VALUES (0, 0, ''); - - INSERT INTO sent_notes (tx, output_pool, output_index, from_account, address, value) - VALUES (0, 2, 0, 0, '', 2); - INSERT INTO sent_notes (tx, output_pool, output_index, from_account, address, value, memo) - VALUES (0, 2, 1, 0, '', 3, X'61'); - INSERT INTO sent_notes (tx, output_pool, output_index, from_account, address, value, memo) - VALUES (0, 2, 2, 0, '', 0, X'f600000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000'); - - INSERT INTO received_notes (tx, output_index, account, diversifier, value, rcm, nf, is_change, memo) - VALUES (0, 0, 0, '', 5, '', 'a', false, X'62'); - INSERT INTO received_notes (tx, output_index, account, diversifier, value, rcm, nf, is_change, memo) - VALUES (0, 1, 0, '', 7, '', 'b', true, X'63');", - ).unwrap(); - - let mut q = db_data - .conn - .prepare("SELECT received_total, received_note_count, memo_count FROM v_tx_received") - .unwrap(); - let mut rows = q.query(NO_PARAMS).unwrap(); - let mut row_count = 0; - while let Some(row) = rows.next().unwrap() { - row_count += 1; - let total: i64 = row.get(0).unwrap(); - let count: i64 = row.get(1).unwrap(); - let memo_count: i64 = row.get(2).unwrap(); - assert_eq!(total, 12); - assert_eq!(count, 2); - assert_eq!(memo_count, 2); - } - assert_eq!(row_count, 1); - - let mut q = db_data - .conn - .prepare("SELECT sent_total, sent_note_count, memo_count FROM v_tx_sent") - .unwrap(); - let mut rows = q.query(NO_PARAMS).unwrap(); - let mut row_count = 0; - while let Some(row) = rows.next().unwrap() { - row_count += 1; - let total: i64 = row.get(0).unwrap(); - let count: i64 = row.get(1).unwrap(); - let memo_count: i64 = row.get(2).unwrap(); - assert_eq!(total, 5); - assert_eq!(count, 3); - assert_eq!(memo_count, 1); - } - assert_eq!(row_count, 1); - - let mut q = db_data - .conn - .prepare("SELECT net_value, has_change, memo_count FROM v_transactions") - .unwrap(); - let mut rows = q.query(NO_PARAMS).unwrap(); - let mut row_count = 0; - while let Some(row) = rows.next().unwrap() { - row_count += 1; - let net_value: i64 = row.get(0).unwrap(); - let has_change: bool = row.get(1).unwrap(); - let memo_count: i64 = row.get(2).unwrap(); - assert_eq!(net_value, 0); - assert!(has_change); - assert_eq!(memo_count, 3); - } - assert_eq!(row_count, 1); - } - #[test] fn init_accounts_table_only_works_once() { let data_file = NamedTempFile::new().unwrap(); diff --git a/zcash_client_sqlite/src/wallet/init/migrations.rs b/zcash_client_sqlite/src/wallet/init/migrations.rs index deea2236c..ab56dc203 100644 --- a/zcash_client_sqlite/src/wallet/init/migrations.rs +++ b/zcash_client_sqlite/src/wallet/init/migrations.rs @@ -1,2 +1,30 @@ mod addresses_table; pub(super) use addresses_table::AddressesTableMigration; + +mod add_transaction_views; + +use schemer_rusqlite::RusqliteMigration; +use secrecy::SecretVec; +use zcash_primitives::consensus; + +use super::{WalletMigration0, WalletMigration1, WalletMigration2, WalletMigrationError}; + +pub(super) fn all_migrations( + params: &P, + seed: Option>, +) -> Vec>> { + vec![ + Box::new(WalletMigration0 {}), + Box::new(WalletMigration1 {}), + Box::new(WalletMigration2 { + params: params.clone(), + seed, + }), + Box::new(AddressesTableMigration { + params: params.clone(), + }), + Box::new(add_transaction_views::Migration { + params: params.clone(), + }), + ] +} diff --git a/zcash_client_sqlite/src/wallet/init/migrations/add_transaction_views.rs b/zcash_client_sqlite/src/wallet/init/migrations/add_transaction_views.rs new file mode 100644 index 000000000..d061ca200 --- /dev/null +++ b/zcash_client_sqlite/src/wallet/init/migrations/add_transaction_views.rs @@ -0,0 +1,279 @@ +//! Functions for initializing the various databases. +use rusqlite::{self, types::ToSql, NO_PARAMS}; +use schemer::{self}; +use schemer_rusqlite::RusqliteMigration; + +use std::collections::HashSet; + +use uuid::Uuid; + +use zcash_primitives::{ + consensus::{self, BlockHeight, BranchId}, + transaction::{components::amount::Amount, Transaction}, +}; + +use super::super::{WalletMigration2, WalletMigrationError}; + +pub(super) struct Migration

{ + pub(super) params: P, +} + +impl

Migration

{ + fn id() -> Uuid { + Uuid::parse_str("282fad2e-8372-4ca0-8bed-71821320909f").unwrap() + } +} + +impl

schemer::Migration for Migration

{ + fn id(&self) -> Uuid { + Migration::

::id() + } + + fn dependencies(&self) -> HashSet { + let mut deps = HashSet::new(); + deps.insert(WalletMigration2::

::id()); + deps + } + + fn description(&self) -> &'static str { + "Add transaction summary views & add fee information to transactions." + } +} + +impl RusqliteMigration for Migration

{ + type Error = WalletMigrationError; + + fn up(&self, transaction: &rusqlite::Transaction) -> Result<(), WalletMigrationError> { + transaction.execute_batch("ALTER TABLE transactions ADD COLUMN fee INTEGER;")?; + + let mut stmt_list_txs = + transaction.prepare("SELECT id_tx, raw, block FROM transactions")?; + + let mut stmt_set_fee = + transaction.prepare("UPDATE transactions SET fee = ? WHERE id_tx = ?")?; + + let mut stmt_find_utxo_value = transaction + .prepare("SELECT value_zat FROM utxos WHERE prevout_txid = ? AND prevout_idx = ?")?; + + let mut tx_rows = stmt_list_txs.query(NO_PARAMS)?; + while let Some(row) = tx_rows.next()? { + let id_tx: i64 = row.get(0)?; + let tx_bytes: Vec = row.get(1)?; + let h: u32 = row.get(2)?; + let block_height = BlockHeight::from(h); + + let tx = Transaction::read( + &tx_bytes[..], + BranchId::for_height(&self.params, block_height), + ) + .map_err(|e| { + WalletMigrationError::CorruptedData(format!( + "Parsing failed for transaction {:?}: {:?}", + id_tx, e + )) + })?; + + let fee_paid = tx.fee_paid(|op| { + stmt_find_utxo_value + .query_row(&[op.hash().to_sql()?, op.n().to_sql()?], |row| { + row.get(0).map(|i| Amount::from_i64(i).unwrap()) + }) + .map_err(WalletMigrationError::DbError) + })?; + + stmt_set_fee.execute(&[i64::from(fee_paid), id_tx])?; + } + + transaction.execute_batch( + "CREATE VIEW v_tx_sent AS + SELECT transactions.id_tx AS id_tx, + transactions.block AS mined_height, + transactions.tx_index AS tx_index, + transactions.txid AS txid, + transactions.expiry_height AS expiry_height, + transactions.raw AS raw, + SUM(sent_notes.value) AS sent_total, + COUNT(sent_notes.id_note) AS sent_note_count, + SUM( + CASE + WHEN sent_notes.memo IS NULL THEN 0 + WHEN SUBSTR(sent_notes.memo, 0, 2) = X'F6' THEN 0 + ELSE 1 + END + ) AS memo_count, + blocks.time AS block_time + FROM transactions + JOIN sent_notes + ON transactions.id_tx = sent_notes.tx + LEFT JOIN blocks + ON transactions.block = blocks.height + GROUP BY sent_notes.tx; + CREATE VIEW v_tx_received AS + SELECT transactions.id_tx AS id_tx, + transactions.block AS mined_height, + transactions.tx_index AS tx_index, + transactions.txid AS txid, + SUM(received_notes.value) AS received_total, + COUNT(received_notes.id_note) AS received_note_count, + SUM( + CASE + WHEN received_notes.memo IS NULL THEN 0 + WHEN SUBSTR(received_notes.memo, 0, 2) = X'F6' THEN 0 + ELSE 1 + END + ) AS memo_count, + blocks.time AS block_time + FROM transactions + JOIN received_notes + ON transactions.id_tx = received_notes.tx + LEFT JOIN blocks + ON transactions.block = blocks.height + GROUP BY received_notes.tx; + CREATE VIEW v_transactions AS + SELECT id_tx, + mined_height, + tx_index, + txid, + expiry_height, + raw, + SUM(value) + MAX(fee) AS net_value, + SUM(is_change) > 0 AS has_change, + SUM(memo_present) AS memo_count + FROM ( + SELECT transactions.id_tx AS id_tx, + transactions.block AS mined_height, + transactions.tx_index AS tx_index, + transactions.txid AS txid, + transactions.expiry_height AS expiry_height, + transactions.raw AS raw, + 0 AS fee, + CASE + WHEN received_notes.is_change THEN 0 + ELSE value + END AS value, + received_notes.is_change AS is_change, + CASE + WHEN received_notes.memo IS NULL THEN 0 + WHEN SUBSTR(received_notes.memo, 0, 2) = X'F6' THEN 0 + ELSE 1 + END AS memo_present + FROM transactions + JOIN received_notes ON transactions.id_tx = received_notes.tx + UNION + SELECT transactions.id_tx AS id_tx, + transactions.block AS mined_height, + transactions.tx_index AS tx_index, + transactions.txid AS txid, + transactions.expiry_height AS expiry_height, + transactions.raw AS raw, + transactions.fee AS fee, + -sent_notes.value AS value, + false AS is_change, + CASE + WHEN sent_notes.memo IS NULL THEN 0 + WHEN SUBSTR(sent_notes.memo, 0, 2) = X'F6' THEN 0 + ELSE 1 + END AS memo_present + FROM transactions + JOIN sent_notes ON transactions.id_tx = sent_notes.tx + ) + GROUP BY id_tx;", + )?; + + Ok(()) + } + + fn down(&self, _transaction: &rusqlite::Transaction) -> Result<(), WalletMigrationError> { + // TODO: something better than just panic? + panic!("Cannot revert this migration."); + } +} + +#[cfg(test)] +mod tests { + use rusqlite::{self, NO_PARAMS}; + + use tempfile::NamedTempFile; + + use crate::{ + tests::{self}, + wallet::init::init_wallet_db, + WalletDb, + }; + + #[test] + fn transaction_views() { + let data_file = NamedTempFile::new().unwrap(); + let mut db_data = WalletDb::for_path(data_file.path(), tests::network()).unwrap(); + init_wallet_db(&mut db_data, None).unwrap(); + + db_data.conn.execute_batch( + "INSERT INTO accounts (account, ufvk) VALUES (0, ''); + INSERT INTO blocks (height, hash, time, sapling_tree) VALUES (0, 0, 0, ''); + INSERT INTO transactions (block, id_tx, txid) VALUES (0, 0, ''); + + INSERT INTO sent_notes (tx, output_pool, output_index, from_account, address, value) + VALUES (0, 2, 0, 0, '', 2); + INSERT INTO sent_notes (tx, output_pool, output_index, from_account, address, value, memo) + VALUES (0, 2, 1, 0, '', 3, X'61'); + INSERT INTO sent_notes (tx, output_pool, output_index, from_account, address, value, memo) + VALUES (0, 2, 2, 0, '', 0, X'f600000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000'); + + INSERT INTO received_notes (tx, output_index, account, diversifier, value, rcm, nf, is_change, memo) + VALUES (0, 0, 0, '', 5, '', 'a', false, X'62'); + INSERT INTO received_notes (tx, output_index, account, diversifier, value, rcm, nf, is_change, memo) + VALUES (0, 1, 0, '', 7, '', 'b', true, X'63');", + ).unwrap(); + + let mut q = db_data + .conn + .prepare("SELECT received_total, received_note_count, memo_count FROM v_tx_received") + .unwrap(); + let mut rows = q.query(NO_PARAMS).unwrap(); + let mut row_count = 0; + while let Some(row) = rows.next().unwrap() { + row_count += 1; + let total: i64 = row.get(0).unwrap(); + let count: i64 = row.get(1).unwrap(); + let memo_count: i64 = row.get(2).unwrap(); + assert_eq!(total, 12); + assert_eq!(count, 2); + assert_eq!(memo_count, 2); + } + assert_eq!(row_count, 1); + + let mut q = db_data + .conn + .prepare("SELECT sent_total, sent_note_count, memo_count FROM v_tx_sent") + .unwrap(); + let mut rows = q.query(NO_PARAMS).unwrap(); + let mut row_count = 0; + while let Some(row) = rows.next().unwrap() { + row_count += 1; + let total: i64 = row.get(0).unwrap(); + let count: i64 = row.get(1).unwrap(); + let memo_count: i64 = row.get(2).unwrap(); + assert_eq!(total, 5); + assert_eq!(count, 3); + assert_eq!(memo_count, 1); + } + assert_eq!(row_count, 1); + + let mut q = db_data + .conn + .prepare("SELECT net_value, has_change, memo_count FROM v_transactions") + .unwrap(); + let mut rows = q.query(NO_PARAMS).unwrap(); + let mut row_count = 0; + while let Some(row) = rows.next().unwrap() { + row_count += 1; + let net_value: i64 = row.get(0).unwrap(); + let has_change: bool = row.get(1).unwrap(); + let memo_count: i64 = row.get(2).unwrap(); + assert_eq!(net_value, 0); + assert!(has_change); + assert_eq!(memo_count, 3); + } + assert_eq!(row_count, 1); + } +}