Issue #244 Add Synchronizer support for shielding funds

This commit is contained in:
Francisco Gindre 2021-01-22 18:51:48 -03:00
parent 6db159c880
commit 81f4edfd55
7 changed files with 145 additions and 54 deletions

View File

@ -47,7 +47,16 @@ class GetUTXOsViewController: UIViewController {
KRProgressHUD.dismiss() KRProgressHUD.dismiss()
switch result { switch result {
case .success(let utxos): case .success(let utxos):
self?.messageLabel.text = "found \(utxos.count) UTXOs for address \(tAddr)" do {
let balance = try AppDelegate.shared.sharedSynchronizer.getUnshieldedBalance(address: tAddr)
self?.messageLabel.text = """
found \(utxos.count) UTXOs for address \(tAddr)
\(balance)
"""
} catch {
self?.messageLabel.text = "Error \(error)"
}
case .failure(let error): case .failure(let error):
self?.messageLabel.text = "Error \(error)" self?.messageLabel.text = "Error \(error)"
@ -110,3 +119,14 @@ extension GetUTXOsViewController: UITextFieldDelegate {
} }
} }
extension UnshieldedBalance {
var description: String {
"""
UnshieldedBalance:
confirmed: \(self.confirmed)
unconfirmed:\(self.unconfirmed)
"""
}
}

View File

@ -123,19 +123,30 @@ class UnspentTransactionOutputSQLDAO: UnspentTransactionOutputRepository {
} }
} }
func balance(address: String) throws -> Int { func balance(address: String, latestHeight: BlockHeight) throws -> UnshieldedBalance {
guard let sum = try dbProvider.connection().scalar( do {
let confirmed = try dbProvider.connection().scalar(
table.select(TableColumns.valueZat.sum)
.filter(TableColumns.address == address)
.filter(TableColumns.height <= latestHeight - ZcashSDK.DEFAULT_STALE_TOLERANCE)) ?? 0
let unconfirmed = try dbProvider.connection().scalar(
table.select(TableColumns.valueZat.sum) table.select(TableColumns.valueZat.sum)
.filter(TableColumns.address == address) .filter(TableColumns.address == address)) ?? 0
) else {
return TransparentBalance(confirmed: Int64(confirmed), unconfirmed: Int64(unconfirmed), address: address)
} catch {
throw StorageError.operationFailed throw StorageError.operationFailed
} }
return sum
} }
} }
struct TransparentBalance: UnshieldedBalance {
var confirmed: Int64
var unconfirmed: Int64
var address: String
}
class UTXORepositoryBuilder { class UTXORepositoryBuilder {
static func build(initializer: Initializer) throws -> UnspentTransactionOutputRepository { static func build(initializer: Initializer) throws -> UnspentTransactionOutputRepository {
let dao = UnspentTransactionOutputSQLDAO(dbProvider: SimpleConnectionProvider(path: initializer.cacheDbURL.path)) let dao = UnspentTransactionOutputSQLDAO(dbProvider: SimpleConnectionProvider(path: initializer.cacheDbURL.path))

View File

@ -7,11 +7,16 @@
import Foundation import Foundation
public protocol UnshieldedBalance {
var confirmed: Int64 { get set }
var unconfirmed: Int64 { get set }
}
protocol UnspentTransactionOutputRepository { protocol UnspentTransactionOutputRepository {
func getAll(address: String?) throws -> [UnspentTransactionOutputEntity] func getAll(address: String?) throws -> [UnspentTransactionOutputEntity]
func balance(address: String) throws -> Int func balance(address: String, latestHeight: BlockHeight) throws -> UnshieldedBalance
func store(utxos: [UnspentTransactionOutputEntity]) throws func store(utxos: [UnspentTransactionOutputEntity]) throws

View File

@ -233,7 +233,7 @@ extension LightWalletGRPCService: LightWalletService {
index: Int(reply.index), index: Int(reply.index),
script: reply.script, script: reply.script,
valueZat: Int(reply.valueZat), valueZat: Int(reply.valueZat),
height: Int(reply.valueZat) height: Int(reply.height)
) )
) )
} }

View File

@ -30,25 +30,14 @@ public enum ShieldFundsError: Error {
case shieldingFailed(underlyingError: Error) case shieldingFailed(underlyingError: Error)
} }
/** /**
Primary interface for interacting with the SDK. Defines the contract that specific Primary interface for interacting with the SDK. Defines the contract that specific
implementations like SdkSynchronizer fulfill. implementations like SdkSynchronizer fulfill.
*/ */
public protocol Synchronizer { public protocol Synchronizer {
/**
Starts this synchronizer within the given scope.
Implementations should leverage structured concurrency and
cancel all jobs when this scope completes.
*/
func start(retry: Bool) throws
/**
Stop this synchronizer. Implementations should ensure that calling this method cancels all
jobs that were created by this instance.
*/
func stop() throws
/** /**
Value representing the Status of this Synchronizer. As the status changes, a new Value representing the Status of this Synchronizer. As the status changes, a new
@ -63,6 +52,22 @@ public protocol Synchronizer {
*/ */
var progress: Float { get } var progress: Float { get }
/**
Starts this synchronizer within the given scope.
Implementations should leverage structured concurrency and
cancel all jobs when this scope completes.
*/
func start(retry: Bool) throws
/**
Stop this synchronizer. Implementations should ensure that calling this method cancels all
jobs that were created by this instance.
*/
func stop() throws
/** /**
Gets the address for the given account. Gets the address for the given account.
- Parameter accountIndex: the optional accountId whose address is of interest. By default, the first account is used. - Parameter accountIndex: the optional accountId whose address is of interest. By default, the first account is used.
@ -157,6 +162,15 @@ public protocol Synchronizer {
*/ */
func cachedUTXOs(address: String) throws -> [UnspentTransactionOutputEntity] func cachedUTXOs(address: String) throws -> [UnspentTransactionOutputEntity]
/**
gets the unshielded balance for the given address.
*/
func latestUnshieldedBalance(address: String, result: @escaping (Result<UnshieldedBalance,Error>) -> Void)
/**
gets the last stored unshielded balance
*/
func getUnshieldedBalance(address: String) throws -> UnshieldedBalance
} }
/** /**

View File

@ -355,16 +355,16 @@ public class SDKSynchronizer: Synchronizer {
do { do {
let tAddr = try derivationTool.deriveTransparentAddressFromPrivateKey(transparentSecretKey) let tAddr = try derivationTool.deriveTransparentAddressFromPrivateKey(transparentSecretKey)
let tBalance = try utxoRepository.balance(address: tAddr) let tBalance = try utxoRepository.balance(address: tAddr, latestHeight: self.latestDownloadedHeight())
guard tBalance > Self.shieldingThreshold else { guard tBalance.confirmed > Self.shieldingThreshold else {
resultBlock(.failure(ShieldFundsError.insuficientTransparentFunds)) resultBlock(.failure(ShieldFundsError.insuficientTransparentFunds))
return return
} }
let vk = try derivationTool.deriveViewingKey(spendingKey: spendingKey) let vk = try derivationTool.deriveViewingKey(spendingKey: spendingKey)
let zAddr = try derivationTool.deriveShieldedAddress(viewingKey: vk) let zAddr = try derivationTool.deriveShieldedAddress(viewingKey: vk)
let shieldingSpend = try transactionManager.initSpend(zatoshi: tBalance, toAddress: zAddr, memo: memo, from: 0) let shieldingSpend = try transactionManager.initSpend(zatoshi: Int(tBalance.confirmed), toAddress: zAddr, memo: memo, from: 0)
transactionManager.encodeShieldingTransaction(spendingKey: spendingKey, tsk: transparentSecretKey, pendingTransaction: shieldingSpend) {[weak self] (result) in transactionManager.encodeShieldingTransaction(spendingKey: spendingKey, tsk: transparentSecretKey, pendingTransaction: shieldingSpend) {[weak self] (result) in
guard let self = self else { return } guard let self = self else { return }
@ -494,6 +494,38 @@ public class SDKSynchronizer: Synchronizer {
try utxoRepository.getAll(address: address) try utxoRepository.getAll(address: address)
} }
/**
gets the unshielded balance for the given address.
*/
public func latestUnshieldedBalance(address: String, result: @escaping (Result<UnshieldedBalance,Error>) -> Void) {
latestUTXOs(address: address, result: { [weak self] (r) in
guard let self = self else { return }
switch r {
case .success:
do {
result(.success(try self.utxoRepository.balance(address: address, latestHeight: try self.latestDownloadedHeight())))
} catch {
result(.failure(SynchronizerError.uncategorized(underlyingError: error)))
}
case .failure(let e):
result(.failure(SynchronizerError.generalError(message: "\(e)")))
}
})
}
/**
gets the last stored unshielded balance
*/
public func getUnshieldedBalance(address: String) throws -> UnshieldedBalance {
do {
let latestHeight = try self.latestDownloadedHeight()
return try utxoRepository.balance(address: address, latestHeight: latestHeight)
} catch {
throw SynchronizerError.uncategorized(underlyingError: error)
}
}
// MARK: notify state // MARK: notify state
private func notify(progress: Float, height: BlockHeight) { private func notify(progress: Float, height: BlockHeight) {
NotificationCenter.default.post(name: Notification.Name.synchronizerProgressUpdated, object: self, userInfo: [ NotificationCenter.default.post(name: Notification.Name.synchronizerProgressUpdated, object: self, userInfo: [

View File

@ -26,11 +26,9 @@ use zcash_client_sqlite::{
error::SqliteClientError, error::SqliteClientError,
wallet::{ wallet::{
init::{init_accounts_table, init_blocks_table, init_wallet_db}, init::{init_accounts_table, init_blocks_table, init_wallet_db},
get_balance,
get_balance_at,
}, },
BlockDB, NoteId, WalletDB, BlockDB, NoteId, WalletDB,
chain::{UnspentTransactionOutput, get_all_utxos, get_confirmed_utxos_for_address} chain::get_confirmed_utxos_for_address,
}; };
use zcash_primitives::{ use zcash_primitives::{
block::BlockHash, block::BlockHash,
@ -97,13 +95,13 @@ pub const NETWORK: MainNetwork = MAIN_NETWORK;
pub const NETWORK: TestNetwork = TEST_NETWORK; pub const NETWORK: TestNetwork = TEST_NETWORK;
fn wallet_db<P: consensus::Parameters>(params: &P,db_data: *const u8, fn wallet_db<P: consensus::Parameters>(params: P,db_data: *const u8,
db_data_len: usize) -> Result<WalletDB<P>, failure::Error> { db_data_len: usize) -> Result<WalletDB<P>, failure::Error> {
let db_data = Path::new(OsStr::from_bytes(unsafe { let db_data = Path::new(OsStr::from_bytes(unsafe {
slice::from_raw_parts(db_data, db_data_len) slice::from_raw_parts(db_data, db_data_len)
})); }));
WalletDB::for_path(db_data, *params) WalletDB::for_path(db_data, params)
.map_err(|e| format_err!("Error opening wallet database connection: {}", e)) .map_err(|e| format_err!("Error opening wallet database connection: {}", e))
} }
@ -168,7 +166,7 @@ pub extern "C" fn zcashlc_init_accounts_table(
capacity_ret: *mut usize, capacity_ret: *mut usize,
) -> *mut *mut c_char { ) -> *mut *mut c_char {
let res = catch_panic(|| { let res = catch_panic(|| {
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?; let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let seed = unsafe { slice::from_raw_parts(seed, seed_len) }; let seed = unsafe { slice::from_raw_parts(seed, seed_len) };
let accounts = if accounts >= 0 { let accounts = if accounts >= 0 {
accounts as u32 accounts as u32
@ -214,7 +212,7 @@ pub extern "C" fn zcashlc_init_accounts_table_with_keys(
extfvks_len: usize, extfvks_len: usize,
) -> bool { ) -> bool {
let res = catch_panic(|| { let res = catch_panic(|| {
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?; let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let extfvks = unsafe { std::slice::from_raw_parts(extfvks, extfvks_len) let extfvks = unsafe { std::slice::from_raw_parts(extfvks, extfvks_len)
.into_iter() .into_iter()
@ -423,7 +421,7 @@ pub extern "C" fn zcashlc_init_blocks_table(
sapling_tree_hex: *const c_char, sapling_tree_hex: *const c_char,
) -> i32 { ) -> i32 {
let res = catch_panic(|| { let res = catch_panic(|| {
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?; let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let hash = { let hash = {
let mut hash = hex::decode(unsafe { CStr::from_ptr(hash_hex) }.to_str()?).unwrap(); let mut hash = hex::decode(unsafe { CStr::from_ptr(hash_hex) }.to_str()?).unwrap();
hash.reverse(); hash.reverse();
@ -450,7 +448,7 @@ pub extern "C" fn zcashlc_get_address(
account: i32, account: i32,
) -> *mut c_char { ) -> *mut c_char {
let res = catch_panic(|| { let res = catch_panic(|| {
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?; let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let account = if account >= 0 { let account = if account >= 0 {
account as u32 account as u32
} else { } else {
@ -537,7 +535,7 @@ fn is_valid_transparent_address(address: &str) -> bool {
#[no_mangle] #[no_mangle]
pub extern "C" fn zcashlc_get_balance(db_data: *const u8, db_data_len: usize, account: i32) -> i64 { pub extern "C" fn zcashlc_get_balance(db_data: *const u8, db_data_len: usize, account: i32) -> i64 {
let res = catch_panic(|| { let res = catch_panic(|| {
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?; let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let account = if account >= 0 { let account = if account >= 0 {
account as u32 account as u32
@ -546,10 +544,25 @@ pub extern "C" fn zcashlc_get_balance(db_data: *const u8, db_data_len: usize, ac
}; };
let account = AccountId(account); let account = AccountId(account);
match db_data.get_balance(account) { // match db_data.get_balance(account) {
Ok(balance) => Ok(balance.into()), // Ok(balance) => Ok(balance.into()),
Err(e) => Err(format_err!("Error while fetching balance: {}", e)), // Err(e) => Err(format_err!("Error while fetching balance: {}", e)),
} // }
(&db_data)
.get_target_and_anchor_heights()
.map_err(|e| format_err!("Error while fetching anchor height: {}", e))
.and_then(|opt_anchor| {
opt_anchor
.map(|(h, _)| h)
.ok_or(format_err!("height not available; scan required."))
})
.and_then(|height| {
(&db_data)
.get_balance_at(account, height)
.map_err(|e| format_err!("Error while fetching verified balance: {}", e))
})
.map(|amount| amount.into())
}); });
unwrap_exc_or(res, -1) unwrap_exc_or(res, -1)
} }
@ -563,7 +576,7 @@ pub extern "C" fn zcashlc_get_verified_balance(
account: i32, account: i32,
) -> i64 { ) -> i64 {
let res = catch_panic(|| { let res = catch_panic(|| {
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?; let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let account = if account >= 0 { let account = if account >= 0 {
account as u32 account as u32
} else { } else {
@ -601,7 +614,7 @@ pub extern "C" fn zcashlc_get_received_memo_as_utf8(
id_note: i64, id_note: i64,
) -> *mut c_char { ) -> *mut c_char {
let res = catch_panic(|| { let res = catch_panic(|| {
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?; let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let memo = match (&db_data).get_received_memo_as_utf8(NoteId(id_note)) { let memo = match (&db_data).get_received_memo_as_utf8(NoteId(id_note)) {
Ok(memo) => memo.unwrap_or_default(), Ok(memo) => memo.unwrap_or_default(),
@ -626,7 +639,7 @@ pub extern "C" fn zcashlc_get_sent_memo_as_utf8(
id_note: i64, id_note: i64,
) -> *mut c_char { ) -> *mut c_char {
let res = catch_panic(|| { let res = catch_panic(|| {
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?; let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let memo = (&db_data) let memo = (&db_data)
.get_sent_memo_as_utf8(NoteId(id_note)) .get_sent_memo_as_utf8(NoteId(id_note))
@ -663,7 +676,7 @@ pub extern "C" fn zcashlc_validate_combined_chain(
) -> i32 { ) -> i32 {
let res = catch_panic(|| { let res = catch_panic(|| {
let block_db = block_db(db_cache, db_cache_len)?; let block_db = block_db(db_cache, db_cache_len)?;
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?; let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let validate_from = (&db_data) let validate_from = (&db_data)
.get_max_height_hash() .get_max_height_hash()
@ -698,7 +711,7 @@ pub extern "C" fn zcashlc_rewind_to_height(
height: i32, height: i32,
) -> i32 { ) -> i32 {
let res = catch_panic(|| { let res = catch_panic(|| {
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?; let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let mut update_ops = (&db_data) let mut update_ops = (&db_data)
.get_update_ops() .get_update_ops()
@ -737,7 +750,8 @@ pub extern "C" fn zcashlc_scan_blocks(
) -> i32 { ) -> i32 {
let res = catch_panic(|| { let res = catch_panic(|| {
let block_db = block_db(db_cache, db_cache_len)?; let block_db = block_db(db_cache, db_cache_len)?;
let mut db_data = wallet_db(&NETWORK, db_data, db_data_len)?.get_update_ops()?; let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let mut db_data = db_data.get_update_ops()?;
match scan_cached_blocks(&NETWORK, &block_db, &mut db_data, None) { match scan_cached_blocks(&NETWORK, &block_db, &mut db_data, None) {
Ok(()) => Ok(1), Ok(()) => Ok(1),
@ -755,7 +769,8 @@ pub extern "C" fn zcashlc_decrypt_and_store_transaction(
tx_len: usize, tx_len: usize,
) -> i32 { ) -> i32 {
let res = catch_panic(|| { let res = catch_panic(|| {
let mut db_data = wallet_db(&NETWORK, db_data, db_data_len)?.get_update_ops()?; let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let mut db_data = db_data.get_update_ops()?;
let tx_bytes = unsafe { slice::from_raw_parts(tx, tx_len) }; let tx_bytes = unsafe { slice::from_raw_parts(tx, tx_len) };
let tx = Transaction::read(&tx_bytes[..])?; let tx = Transaction::read(&tx_bytes[..])?;
@ -791,7 +806,8 @@ pub extern "C" fn zcashlc_create_to_address(
) -> i64 { ) -> i64 {
let res = catch_panic(|| { let res = catch_panic(|| {
let mut db_data = wallet_db(&NETWORK, db_data, db_data_len)?.get_update_ops()?; let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let mut db_data = db_data.get_update_ops()?;
let account = if account >= 0 { let account = if account >= 0 {
account as u32 account as u32
} else { } else {
@ -1083,13 +1099,6 @@ fn shield_funds<P: consensus::Parameters>(
// derive the corresponding t-address // derive the corresponding t-address
let t_addr_str = derive_transparent_address_from_secret_key(sk); let t_addr_str = derive_transparent_address_from_secret_key(sk);
let t_addr = match RecipientAddress::decode(&NETWORK, &t_addr_str) {
Some(to) => to,
None => {
return Err(format_err!("PaymentAddress is for the wrong network"));
},
};
let extsk = let extsk =
match decode_extended_spending_key(NETWORK.hrp_sapling_extended_spending_key(), &extsk) match decode_extended_spending_key(NETWORK.hrp_sapling_extended_spending_key(), &extsk)
{ {
@ -1251,7 +1260,7 @@ pub extern "C" fn zcashlc_shield_funds(
) -> i64 { ) -> i64 {
let res = catch_panic(|| { let res = catch_panic(|| {
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?; let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let db_cache = block_db(db_cache, db_cache_len)?; let db_cache = block_db(db_cache, db_cache_len)?;
let account = if account >= 0 { let account = if account >= 0 {
account as u32 account as u32