zcash_client_sqlite: Refactor `wallet::Account` to be a struct

This commit is contained in:
Jack Grigg 2024-03-13 18:33:28 +00:00
parent 634ebf51ef
commit bc6aa955ff
4 changed files with 195 additions and 204 deletions

View File

@ -100,7 +100,7 @@ pub mod error;
pub mod wallet;
use wallet::{
commitment_tree::{self, put_shard_roots},
Account, HdSeedAccount, SubtreeScanProgress,
AccountType, SubtreeScanProgress,
};
#[cfg(test)]
@ -307,16 +307,19 @@ impl<C: Borrow<rusqlite::Connection>, P: consensus::Parameters> WalletRead for W
seed: &SecretVec<u8>,
) -> Result<bool, Self::Error> {
if let Some(account) = wallet::get_account(self, account_id)? {
if let Account::Zip32(hdaccount) = account {
let seed_fingerprint_match =
HdSeedFingerprint::from_seed(seed) == *hdaccount.hd_seed_fingerprint();
if let AccountType::Derived {
seed_fingerprint,
account_index,
} = account.kind
{
let seed_fingerprint_match = HdSeedFingerprint::from_seed(seed) == seed_fingerprint;
let usk = UnifiedSpendingKey::from_seed(
&self.params,
&seed.expose_secret()[..],
hdaccount.account_index(),
account_index,
)
.map_err(|_| SqliteClientError::KeyDerivationError(hdaccount.account_index()))?;
.map_err(|_| SqliteClientError::KeyDerivationError(account_index))?;
// Keys are not comparable with `Eq`, but addresses are, so we derive what should
// be equivalent addresses for each key and use those to check for key equality.
@ -326,7 +329,7 @@ impl<C: Borrow<rusqlite::Connection>, P: consensus::Parameters> WalletRead for W
Ok(usk
.to_unified_full_viewing_key()
.default_address(ua_request)?
== hdaccount.ufvk().default_address(ua_request)?)
== account.default_address(ua_request)?)
},
)?;
@ -490,8 +493,8 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
birthday: AccountBirthday,
) -> Result<(AccountId, UnifiedSpendingKey), Self::Error> {
self.transactionally(|wdb| {
let seed_id = HdSeedFingerprint::from_seed(seed);
let account_index = wallet::max_zip32_account_index(wdb.conn.0, &seed_id)?
let seed_fingerprint = HdSeedFingerprint::from_seed(seed);
let account_index = wallet::max_zip32_account_index(wdb.conn.0, &seed_fingerprint)?
.map(|a| a.next().ok_or(SqliteClientError::AccountIdOutOfRange))
.transpose()?
.unwrap_or(zip32::AccountId::ZERO);
@ -501,8 +504,16 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
.map_err(|_| SqliteClientError::KeyDerivationError(account_index))?;
let ufvk = usk.to_unified_full_viewing_key();
let account = Account::Zip32(HdSeedAccount::new(seed_id, account_index, ufvk));
let account_id = wallet::add_account(wdb.conn.0, &wdb.params, account, birthday)?;
let account_id = wallet::add_account(
wdb.conn.0,
&wdb.params,
AccountType::Derived {
seed_fingerprint,
account_index,
},
wallet::ViewingKey::Full(Box::new(ufvk)),
birthday,
)?;
Ok((account_id, usk))
})

View File

@ -140,82 +140,71 @@ pub(crate) const BLOCK_SAPLING_FRONTIER_ABSENT: &[u8] = &[0x0];
/// This tracks the allowed values of the `account_type` column of the `accounts` table
/// and should not be made public.
enum AccountType {
Zip32,
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub(crate) enum AccountType {
/// An account derived from a known seed.
Derived {
seed_fingerprint: HdSeedFingerprint,
account_index: zip32::AccountId,
},
/// An account imported from a viewing key.
Imported,
}
impl TryFrom<u32> for AccountType {
type Error = ();
fn try_from(value: u32) -> Result<Self, Self::Error> {
match value {
0 => Ok(AccountType::Zip32),
1 => Ok(AccountType::Imported),
_ => Err(()),
}
fn parse_account_kind(
account_type: u32,
hd_seed_fingerprint: Option<[u8; 32]>,
hd_account_index: Option<u32>,
) -> Result<AccountType, SqliteClientError> {
match (account_type, hd_seed_fingerprint, hd_account_index) {
(0, Some(seed_fp), Some(account_index)) => Ok(AccountType::Derived {
seed_fingerprint: HdSeedFingerprint::from_bytes(seed_fp),
account_index: zip32::AccountId::try_from(account_index).map_err(|_| {
SqliteClientError::CorruptedData(
"ZIP-32 account ID from wallet DB is out of range.".to_string(),
)
})?,
}),
(1, None, None) => Ok(AccountType::Imported),
(0, None, None) | (1, Some(_), Some(_)) => Err(SqliteClientError::CorruptedData(
"Wallet DB account_type constraint violated".to_string(),
)),
(_, _, _) => Err(SqliteClientError::CorruptedData(
"Unrecognized account_type".to_string(),
)),
}
}
impl From<AccountType> for u32 {
fn from(value: AccountType) -> Self {
match value {
AccountType::Zip32 => 0,
AccountType::Imported => 1,
}
fn account_kind_code(value: AccountType) -> u32 {
match value {
AccountType::Derived { .. } => 0,
AccountType::Imported => 1,
}
}
/// Describes the key inputs and UFVK for an account that was derived from a ZIP-32 HD seed and account index.
/// The viewing key that an [`Account`] has available to it.
#[derive(Debug, Clone)]
pub(crate) struct HdSeedAccount(
HdSeedFingerprint,
zip32::AccountId,
Box<UnifiedFullViewingKey>,
);
impl HdSeedAccount {
pub fn new(
hd_seed_fingerprint: HdSeedFingerprint,
account_index: zip32::AccountId,
ufvk: UnifiedFullViewingKey,
) -> Self {
Self(hd_seed_fingerprint, account_index, Box::new(ufvk))
}
/// Returns the HD seed fingerprint for this account.
pub fn hd_seed_fingerprint(&self) -> &HdSeedFingerprint {
&self.0
}
/// Returns the ZIP-32 account index for this account.
pub fn account_index(&self) -> zip32::AccountId {
self.1
}
/// Returns the Unified Full Viewing Key for this account.
pub fn ufvk(&self) -> &UnifiedFullViewingKey {
&self.2
}
}
/// Represents an arbitrary account for which the seed and ZIP-32 account ID are not known
/// and may not have been involved in creating this account.
#[derive(Debug, Clone)]
pub(crate) enum ImportedAccount {
/// An account that was imported via its full viewing key.
pub(crate) enum ViewingKey {
/// A full viewing key.
///
/// This is available to derived accounts, as well as accounts directly imported as
/// full viewing keys.
Full(Box<UnifiedFullViewingKey>),
/// An account that was imported via its incoming viewing key.
/// An incoming viewing key.
///
/// Accounts that have this kind of viewing key cannot be used in wallet contexts,
/// because they are unable to maintain an accurate balance.
Incoming(Uivk),
}
/// Describes an account in terms of its UVK or ZIP-32 origins.
/// An account stored in a `zcash_client_sqlite` database.
#[derive(Debug, Clone)]
pub(crate) enum Account {
/// Inputs for a ZIP-32 HD account.
Zip32(HdSeedAccount),
/// Inputs for an imported account.
Imported(ImportedAccount),
pub(crate) struct Account {
account_id: AccountId,
pub(crate) kind: AccountType,
viewing_key: ViewingKey,
}
impl Account {
@ -228,10 +217,35 @@ impl Account {
&self,
request: UnifiedAddressRequest,
) -> Result<(UnifiedAddress, DiversifierIndex), AddressGenerationError> {
match &self.viewing_key {
ViewingKey::Full(ufvk) => ufvk.default_address(request),
ViewingKey::Incoming(_uivk) => todo!(),
}
}
}
impl zcash_client_backend::data_api::Account<AccountId> for Account {
fn id(&self) -> AccountId {
self.account_id
}
fn ufvk(&self) -> Option<&UnifiedFullViewingKey> {
self.viewing_key.ufvk()
}
}
impl ViewingKey {
fn ufvk(&self) -> Option<&UnifiedFullViewingKey> {
match self {
Account::Zip32(HdSeedAccount(_, _, ufvk)) => ufvk.default_address(request),
Account::Imported(ImportedAccount::Full(ufvk)) => ufvk.default_address(request),
Account::Imported(ImportedAccount::Incoming(_uivk)) => todo!(),
ViewingKey::Full(ufvk) => Some(ufvk),
ViewingKey::Incoming(_) => None,
}
}
fn uivk_str(&self, params: &impl Parameters) -> Result<String, SqliteClientError> {
match self {
ViewingKey::Full(ufvk) => ufvk_to_uivk(ufvk, params),
ViewingKey::Incoming(uivk) => Ok(uivk.encode(&params.network_type())),
}
}
}
@ -291,44 +305,6 @@ pub(crate) fn max_zip32_account_index(
)
}
struct AccountSqlValues<'a> {
account_type: u32,
hd_seed_fingerprint: Option<&'a [u8]>,
hd_account_index: Option<u32>,
ufvk: Option<&'a UnifiedFullViewingKey>,
uivk: String,
}
/// Returns (account_type, hd_seed_fingerprint, hd_account_index, ufvk, uivk) for a given account.
fn get_sql_values_for_account_parameters<'a, P: consensus::Parameters>(
account: &'a Account,
params: &P,
) -> Result<AccountSqlValues<'a>, SqliteClientError> {
Ok(match account {
Account::Zip32(hdaccount) => AccountSqlValues {
account_type: AccountType::Zip32.into(),
hd_seed_fingerprint: Some(hdaccount.hd_seed_fingerprint().as_bytes()),
hd_account_index: Some(hdaccount.account_index().into()),
ufvk: Some(hdaccount.ufvk()),
uivk: ufvk_to_uivk(hdaccount.ufvk(), params)?,
},
Account::Imported(ImportedAccount::Full(ufvk)) => AccountSqlValues {
account_type: AccountType::Imported.into(),
hd_seed_fingerprint: None,
hd_account_index: None,
ufvk: Some(ufvk),
uivk: ufvk_to_uivk(ufvk, params)?,
},
Account::Imported(ImportedAccount::Incoming(uivk)) => AccountSqlValues {
account_type: AccountType::Imported.into(),
hd_seed_fingerprint: None,
hd_account_index: None,
ufvk: None,
uivk: uivk.encode(&params.network_type()),
},
})
}
pub(crate) fn ufvk_to_uivk<P: consensus::Parameters>(
ufvk: &UnifiedFullViewingKey,
params: &P,
@ -357,20 +333,27 @@ pub(crate) fn ufvk_to_uivk<P: consensus::Parameters>(
pub(crate) fn add_account<P: consensus::Parameters>(
conn: &rusqlite::Transaction,
params: &P,
account: Account,
kind: AccountType,
viewing_key: ViewingKey,
birthday: AccountBirthday,
) -> Result<AccountId, SqliteClientError> {
let args = get_sql_values_for_account_parameters(&account, params)?;
let (hd_seed_fingerprint, hd_account_index) = match kind {
AccountType::Derived {
seed_fingerprint,
account_index,
} => (Some(seed_fingerprint), Some(account_index)),
AccountType::Imported => (None, None),
};
let orchard_item = args
.ufvk
let orchard_item = viewing_key
.ufvk()
.and_then(|ufvk| ufvk.orchard().map(|k| k.to_bytes()));
let sapling_item = args
.ufvk
let sapling_item = viewing_key
.ufvk()
.and_then(|ufvk| ufvk.sapling().map(|k| k.to_bytes()));
#[cfg(feature = "transparent-inputs")]
let transparent_item = args
.ufvk
let transparent_item = viewing_key
.ufvk()
.and_then(|ufvk| ufvk.transparent().map(|k| k.serialize()));
#[cfg(not(feature = "transparent-inputs"))]
let transparent_item: Option<Vec<u8>> = None;
@ -392,11 +375,11 @@ pub(crate) fn add_account<P: consensus::Parameters>(
RETURNING id;
"#,
named_params![
":account_type": args.account_type,
":hd_seed_fingerprint": args.hd_seed_fingerprint,
":hd_account_index": args.hd_account_index,
":ufvk": args.ufvk.map(|ufvk| ufvk.encode(params)),
":uivk": args.uivk,
":account_type": account_kind_code(kind),
":hd_seed_fingerprint": hd_seed_fingerprint.as_ref().map(|fp| fp.as_bytes()),
":hd_account_index": hd_account_index.map(u32::from),
":ufvk": viewing_key.ufvk().map(|ufvk| ufvk.encode(params)),
":uivk": viewing_key.uivk_str(params)?,
":orchard_fvk_item_cache": orchard_item,
":sapling_fvk_item_cache": sapling_item,
":p2pkh_fvk_item_cache": transparent_item,
@ -406,6 +389,12 @@ pub(crate) fn add_account<P: consensus::Parameters>(
|row| Ok(AccountId(row.get(0)?)),
)?;
let account = Account {
account_id,
kind,
viewing_key,
};
// If a birthday frontier is available, insert it into the note commitment tree. If the
// birthday frontier is the empty frontier, we don't need to do anything.
if let Some(frontier) = birthday.sapling_frontier().value() {
@ -754,18 +743,18 @@ pub(crate) fn get_account_for_ufvk<P: consensus::Parameters>(
],
|row| {
let account_id = row.get::<_, u32>(0).map(AccountId)?;
Ok((
account_id,
row.get::<_, Option<String>>(1)?
.map(|ufvk_str| UnifiedFullViewingKey::decode(params, &ufvk_str))
.transpose()
.map_err(|e| {
SqliteClientError::CorruptedData(format!(
"Could not decode unified full viewing key for account {:?}: {}",
account_id, e
))
})?,
))
// We looked up the account by FVK components, so the UFVK column must be
// non-null.
let ufvk_str: String = row.get(1)?;
let ufvk = UnifiedFullViewingKey::decode(params, &ufvk_str).map_err(|e| {
SqliteClientError::CorruptedData(format!(
"Could not decode unified full viewing key for account {:?}: {}",
account_id, e
))
})?;
Ok((account_id, Some(ufvk)))
},
)?
.collect::<Result<Vec<_>, _>>()?;
@ -785,7 +774,7 @@ pub(crate) fn get_seed_account<P: consensus::Parameters>(
conn: &rusqlite::Connection,
params: &P,
seed: &HdSeedFingerprint,
account_id: zip32::AccountId,
account_index: zip32::AccountId,
) -> Result<Option<(AccountId, Option<UnifiedFullViewingKey>)>, SqliteClientError> {
let mut stmt = conn.prepare(
"SELECT id, ufvk
@ -797,22 +786,23 @@ pub(crate) fn get_seed_account<P: consensus::Parameters>(
let mut accounts = stmt.query_and_then::<_, SqliteClientError, _, _>(
named_params![
":hd_seed_fingerprint": seed.as_bytes(),
":hd_account_index": u32::from(account_id),
":hd_account_index": u32::from(account_index),
],
|row| {
let account_id = row.get::<_, u32>(0).map(AccountId)?;
Ok((
account_id,
row.get::<_, Option<String>>(1)?
.map(|ufvk_str| UnifiedFullViewingKey::decode(params, &ufvk_str))
.transpose()
.map_err(|e| {
SqliteClientError::CorruptedData(format!(
"Could not decode unified full viewing key for account {:?}: {}",
account_id, e
))
})?,
))
let ufvk = match row.get::<_, Option<String>>(1)? {
None => Err(SqliteClientError::CorruptedData(format!(
"Missing unified full viewing key for derived account {:?}",
account_id,
))),
Some(ufvk_str) => UnifiedFullViewingKey::decode(params, &ufvk_str).map_err(|e| {
SqliteClientError::CorruptedData(format!(
"Could not decode unified full viewing key for account {:?}: {}",
account_id, e
))
}),
}?;
Ok((account_id, Some(ufvk)))
},
)?;
@ -1506,7 +1496,7 @@ pub(crate) fn get_account<C: Borrow<rusqlite::Connection>, P: Parameters>(
) -> Result<Option<Account>, SqliteClientError> {
let mut sql = db.conn.borrow().prepare_cached(
r#"
SELECT account_type, ufvk, uivk, hd_seed_fingerprint, hd_account_index
SELECT account_type, hd_seed_fingerprint, hd_account_index, ufvk, uivk
FROM accounts
WHERE id = :account_id
"#,
@ -1516,51 +1506,36 @@ pub(crate) fn get_account<C: Borrow<rusqlite::Connection>, P: Parameters>(
let row = result.next()?;
match row {
Some(row) => {
let account_type: AccountType =
row.get::<_, u32>("account_type")?.try_into().map_err(|_| {
SqliteClientError::CorruptedData("Unrecognized account_type".to_string())
})?;
let kind = parse_account_kind(
row.get("account_type")?,
row.get("hd_seed_fingerprint")?,
row.get("hd_account_index")?,
)?;
let ufvk_str: Option<String> = row.get("ufvk")?;
let ufvk = if let Some(ufvk_str) = ufvk_str {
Some(
let viewing_key = if let Some(ufvk_str) = ufvk_str {
ViewingKey::Full(Box::new(
UnifiedFullViewingKey::decode(&db.params, &ufvk_str[..])
.map_err(SqliteClientError::BadAccountData)?,
)
))
} else {
None
let uivk_str: String = row.get("uivk")?;
let (network, uivk) = Uivk::decode(&uivk_str).map_err(|e| {
SqliteClientError::CorruptedData(format!("Failure to decode UIVK: {e}"))
})?;
if network != db.params.network_type() {
return Err(SqliteClientError::CorruptedData(
"UIVK network type does not match wallet network type".to_string(),
));
}
ViewingKey::Incoming(uivk)
};
let uivk_str: String = row.get("uivk")?;
let (network, uivk) = Uivk::decode(&uivk_str).map_err(|e| {
SqliteClientError::CorruptedData(format!("Failure to decode UIVK: {e}"))
})?;
if network != db.params.network_type() {
return Err(SqliteClientError::CorruptedData(
"UIVK network type does not match wallet network type".to_string(),
));
}
match account_type {
AccountType::Zip32 => Ok(Some(Account::Zip32(HdSeedAccount::new(
HdSeedFingerprint::from_bytes(row.get("hd_seed_fingerprint")?),
zip32::AccountId::try_from(row.get::<_, u32>("hd_account_index")?).map_err(
|_| {
SqliteClientError::CorruptedData(
"ZIP-32 account ID from db is out of range.".to_string(),
)
},
)?,
ufvk.ok_or_else(|| {
SqliteClientError::CorruptedData(
"ZIP-32 account is missing a full viewing key".to_string(),
)
})?,
)))),
AccountType::Imported => Ok(Some(Account::Imported(if let Some(ufvk) = ufvk {
ImportedAccount::Full(Box::new(ufvk))
} else {
ImportedAccount::Incoming(uivk)
}))),
}
Ok(Some(Account {
account_id,
kind,
viewing_key,
}))
}
None => Ok(None),
}
@ -2725,7 +2700,7 @@ mod tests {
use crate::{
testing::{AddressType, BlockCache, TestBuilder, TestState},
wallet::{get_account, Account},
wallet::{get_account, AccountType},
AccountId,
};
@ -2877,8 +2852,8 @@ mod tests {
let expected_account_index = zip32::AccountId::try_from(0).unwrap();
assert_matches!(
account_parameters,
Account::Zip32(hdaccount) if hdaccount.account_index() == expected_account_index
account_parameters.kind,
AccountType::Derived{account_index, ..} if account_index == expected_account_index
);
}

View File

@ -1284,7 +1284,7 @@ mod tests {
fn account_produces_expected_ua_sequence() {
use zcash_client_backend::data_api::AccountBirthday;
use crate::wallet::{get_account, Account};
use crate::wallet::{get_account, AccountType};
let network = Network::MainNetwork;
let data_file = NamedTempFile::new().unwrap();
@ -1301,7 +1301,10 @@ mod tests {
.unwrap();
assert_matches!(
get_account(&db_data, account_id),
Ok(Some(Account::Zip32(hdaccount))) if hdaccount.account_index() == zip32::AccountId::ZERO
Ok(Some(account)) if matches!(
account.kind,
AccountType::Derived{account_index, ..} if account_index == zip32::AccountId::ZERO,
)
);
for tv in &test_vectors::UNIFIED[..3] {

View File

@ -1,6 +1,6 @@
use std::{collections::HashSet, rc::Rc};
use crate::wallet::{init::WalletMigrationError, ufvk_to_uivk, AccountType};
use crate::wallet::{account_kind_code, init::WalletMigrationError, ufvk_to_uivk, AccountType};
use rusqlite::{named_params, Transaction};
use schemer_rusqlite::RusqliteMigration;
use secrecy::{ExposeSecret, SecretVec};
@ -44,8 +44,11 @@ impl<P: consensus::Parameters> RusqliteMigration for Migration<P> {
type Error = WalletMigrationError;
fn up(&self, transaction: &Transaction) -> Result<(), WalletMigrationError> {
let account_type_zip32 = u32::from(AccountType::Zip32);
let account_type_imported = u32::from(AccountType::Imported);
let account_type_derived = account_kind_code(AccountType::Derived {
seed_fingerprint: HdSeedFingerprint::from_bytes([0; 32]),
account_index: zip32::AccountId::ZERO,
});
let account_type_imported = account_kind_code(AccountType::Imported);
transaction.execute_batch(
&format!(r#"
PRAGMA foreign_keys = OFF;
@ -53,7 +56,7 @@ impl<P: consensus::Parameters> RusqliteMigration for Migration<P> {
CREATE TABLE accounts_new (
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
account_type INTEGER NOT NULL DEFAULT {account_type_zip32},
account_type INTEGER NOT NULL DEFAULT {account_type_derived},
hd_seed_fingerprint BLOB,
hd_account_index INTEGER,
ufvk TEXT,
@ -64,7 +67,7 @@ impl<P: consensus::Parameters> RusqliteMigration for Migration<P> {
birthday_height INTEGER NOT NULL,
recover_until_height INTEGER,
CHECK (
(account_type = {account_type_zip32} AND hd_seed_fingerprint IS NOT NULL AND hd_account_index IS NOT NULL AND ufvk IS NOT NULL)
(account_type = {account_type_derived} AND hd_seed_fingerprint IS NOT NULL AND hd_account_index IS NOT NULL AND ufvk IS NOT NULL)
OR
(account_type = {account_type_imported} AND hd_seed_fingerprint IS NULL AND hd_account_index IS NULL)
)
@ -85,7 +88,6 @@ impl<P: consensus::Parameters> RusqliteMigration for Migration<P> {
let mut rows = q.query([])?;
while let Some(row) = rows.next()? {
let account_index: u32 = row.get("account")?;
let account_type = u32::from(AccountType::Zip32);
let birthday_height: u32 = row.get("birthday_height")?;
let recover_until_height: Option<u32> = row.get("recover_until_height")?;
@ -142,7 +144,7 @@ impl<P: consensus::Parameters> RusqliteMigration for Migration<P> {
"#,
named_params![
":account_id": account_id,
":account_type": account_type,
":account_type": account_type_derived,
":seed_id": seed_id.as_bytes(),
":account_index": account_index,
":ufvk": ufvk,