Issue #244 Add Synchronizer support for shielding funds

This commit is contained in:
Francisco Gindre 2021-01-22 18:51:48 -03:00
parent 6db159c880
commit 81f4edfd55
7 changed files with 145 additions and 54 deletions

View File

@ -47,7 +47,16 @@ class GetUTXOsViewController: UIViewController {
KRProgressHUD.dismiss()
switch result {
case .success(let utxos):
self?.messageLabel.text = "found \(utxos.count) UTXOs for address \(tAddr)"
do {
let balance = try AppDelegate.shared.sharedSynchronizer.getUnshieldedBalance(address: tAddr)
self?.messageLabel.text = """
found \(utxos.count) UTXOs for address \(tAddr)
\(balance)
"""
} catch {
self?.messageLabel.text = "Error \(error)"
}
case .failure(let error):
self?.messageLabel.text = "Error \(error)"
@ -110,3 +119,14 @@ extension GetUTXOsViewController: UITextFieldDelegate {
}
}
extension UnshieldedBalance {
var description: String {
"""
UnshieldedBalance:
confirmed: \(self.confirmed)
unconfirmed:\(self.unconfirmed)
"""
}
}

View File

@ -123,19 +123,30 @@ class UnspentTransactionOutputSQLDAO: UnspentTransactionOutputRepository {
}
}
func balance(address: String) throws -> Int {
func balance(address: String, latestHeight: BlockHeight) throws -> UnshieldedBalance {
guard let sum = try dbProvider.connection().scalar(
do {
let confirmed = try dbProvider.connection().scalar(
table.select(TableColumns.valueZat.sum)
.filter(TableColumns.address == address)
.filter(TableColumns.height <= latestHeight - ZcashSDK.DEFAULT_STALE_TOLERANCE)) ?? 0
let unconfirmed = try dbProvider.connection().scalar(
table.select(TableColumns.valueZat.sum)
.filter(TableColumns.address == address)
) else {
.filter(TableColumns.address == address)) ?? 0
return TransparentBalance(confirmed: Int64(confirmed), unconfirmed: Int64(unconfirmed), address: address)
} catch {
throw StorageError.operationFailed
}
return sum
}
}
struct TransparentBalance: UnshieldedBalance {
var confirmed: Int64
var unconfirmed: Int64
var address: String
}
class UTXORepositoryBuilder {
static func build(initializer: Initializer) throws -> UnspentTransactionOutputRepository {
let dao = UnspentTransactionOutputSQLDAO(dbProvider: SimpleConnectionProvider(path: initializer.cacheDbURL.path))

View File

@ -7,11 +7,16 @@
import Foundation
public protocol UnshieldedBalance {
var confirmed: Int64 { get set }
var unconfirmed: Int64 { get set }
}
protocol UnspentTransactionOutputRepository {
func getAll(address: String?) throws -> [UnspentTransactionOutputEntity]
func balance(address: String) throws -> Int
func balance(address: String, latestHeight: BlockHeight) throws -> UnshieldedBalance
func store(utxos: [UnspentTransactionOutputEntity]) throws

View File

@ -233,7 +233,7 @@ extension LightWalletGRPCService: LightWalletService {
index: Int(reply.index),
script: reply.script,
valueZat: Int(reply.valueZat),
height: Int(reply.valueZat)
height: Int(reply.height)
)
)
}

View File

@ -30,25 +30,14 @@ public enum ShieldFundsError: Error {
case shieldingFailed(underlyingError: Error)
}
/**
Primary interface for interacting with the SDK. Defines the contract that specific
implementations like SdkSynchronizer fulfill.
*/
public protocol Synchronizer {
/**
Starts this synchronizer within the given scope.
Implementations should leverage structured concurrency and
cancel all jobs when this scope completes.
*/
func start(retry: Bool) throws
/**
Stop this synchronizer. Implementations should ensure that calling this method cancels all
jobs that were created by this instance.
*/
func stop() throws
/**
Value representing the Status of this Synchronizer. As the status changes, a new
@ -63,6 +52,22 @@ public protocol Synchronizer {
*/
var progress: Float { get }
/**
Starts this synchronizer within the given scope.
Implementations should leverage structured concurrency and
cancel all jobs when this scope completes.
*/
func start(retry: Bool) throws
/**
Stop this synchronizer. Implementations should ensure that calling this method cancels all
jobs that were created by this instance.
*/
func stop() throws
/**
Gets the address for the given account.
- Parameter accountIndex: the optional accountId whose address is of interest. By default, the first account is used.
@ -157,6 +162,15 @@ public protocol Synchronizer {
*/
func cachedUTXOs(address: String) throws -> [UnspentTransactionOutputEntity]
/**
gets the unshielded balance for the given address.
*/
func latestUnshieldedBalance(address: String, result: @escaping (Result<UnshieldedBalance,Error>) -> Void)
/**
gets the last stored unshielded balance
*/
func getUnshieldedBalance(address: String) throws -> UnshieldedBalance
}
/**

View File

@ -355,16 +355,16 @@ public class SDKSynchronizer: Synchronizer {
do {
let tAddr = try derivationTool.deriveTransparentAddressFromPrivateKey(transparentSecretKey)
let tBalance = try utxoRepository.balance(address: tAddr)
let tBalance = try utxoRepository.balance(address: tAddr, latestHeight: self.latestDownloadedHeight())
guard tBalance > Self.shieldingThreshold else {
guard tBalance.confirmed > Self.shieldingThreshold else {
resultBlock(.failure(ShieldFundsError.insuficientTransparentFunds))
return
}
let vk = try derivationTool.deriveViewingKey(spendingKey: spendingKey)
let zAddr = try derivationTool.deriveShieldedAddress(viewingKey: vk)
let shieldingSpend = try transactionManager.initSpend(zatoshi: tBalance, toAddress: zAddr, memo: memo, from: 0)
let shieldingSpend = try transactionManager.initSpend(zatoshi: Int(tBalance.confirmed), toAddress: zAddr, memo: memo, from: 0)
transactionManager.encodeShieldingTransaction(spendingKey: spendingKey, tsk: transparentSecretKey, pendingTransaction: shieldingSpend) {[weak self] (result) in
guard let self = self else { return }
@ -494,6 +494,38 @@ public class SDKSynchronizer: Synchronizer {
try utxoRepository.getAll(address: address)
}
/**
gets the unshielded balance for the given address.
*/
public func latestUnshieldedBalance(address: String, result: @escaping (Result<UnshieldedBalance,Error>) -> Void) {
latestUTXOs(address: address, result: { [weak self] (r) in
guard let self = self else { return }
switch r {
case .success:
do {
result(.success(try self.utxoRepository.balance(address: address, latestHeight: try self.latestDownloadedHeight())))
} catch {
result(.failure(SynchronizerError.uncategorized(underlyingError: error)))
}
case .failure(let e):
result(.failure(SynchronizerError.generalError(message: "\(e)")))
}
})
}
/**
gets the last stored unshielded balance
*/
public func getUnshieldedBalance(address: String) throws -> UnshieldedBalance {
do {
let latestHeight = try self.latestDownloadedHeight()
return try utxoRepository.balance(address: address, latestHeight: latestHeight)
} catch {
throw SynchronizerError.uncategorized(underlyingError: error)
}
}
// MARK: notify state
private func notify(progress: Float, height: BlockHeight) {
NotificationCenter.default.post(name: Notification.Name.synchronizerProgressUpdated, object: self, userInfo: [

View File

@ -26,11 +26,9 @@ use zcash_client_sqlite::{
error::SqliteClientError,
wallet::{
init::{init_accounts_table, init_blocks_table, init_wallet_db},
get_balance,
get_balance_at,
},
BlockDB, NoteId, WalletDB,
chain::{UnspentTransactionOutput, get_all_utxos, get_confirmed_utxos_for_address}
chain::get_confirmed_utxos_for_address,
};
use zcash_primitives::{
block::BlockHash,
@ -97,13 +95,13 @@ pub const NETWORK: MainNetwork = MAIN_NETWORK;
pub const NETWORK: TestNetwork = TEST_NETWORK;
fn wallet_db<P: consensus::Parameters>(params: &P,db_data: *const u8,
fn wallet_db<P: consensus::Parameters>(params: P,db_data: *const u8,
db_data_len: usize) -> Result<WalletDB<P>, failure::Error> {
let db_data = Path::new(OsStr::from_bytes(unsafe {
slice::from_raw_parts(db_data, db_data_len)
}));
WalletDB::for_path(db_data, *params)
WalletDB::for_path(db_data, params)
.map_err(|e| format_err!("Error opening wallet database connection: {}", e))
}
@ -168,7 +166,7 @@ pub extern "C" fn zcashlc_init_accounts_table(
capacity_ret: *mut usize,
) -> *mut *mut c_char {
let res = catch_panic(|| {
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?;
let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let seed = unsafe { slice::from_raw_parts(seed, seed_len) };
let accounts = if accounts >= 0 {
accounts as u32
@ -214,7 +212,7 @@ pub extern "C" fn zcashlc_init_accounts_table_with_keys(
extfvks_len: usize,
) -> bool {
let res = catch_panic(|| {
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?;
let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let extfvks = unsafe { std::slice::from_raw_parts(extfvks, extfvks_len)
.into_iter()
@ -423,7 +421,7 @@ pub extern "C" fn zcashlc_init_blocks_table(
sapling_tree_hex: *const c_char,
) -> i32 {
let res = catch_panic(|| {
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?;
let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let hash = {
let mut hash = hex::decode(unsafe { CStr::from_ptr(hash_hex) }.to_str()?).unwrap();
hash.reverse();
@ -450,7 +448,7 @@ pub extern "C" fn zcashlc_get_address(
account: i32,
) -> *mut c_char {
let res = catch_panic(|| {
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?;
let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let account = if account >= 0 {
account as u32
} else {
@ -537,7 +535,7 @@ fn is_valid_transparent_address(address: &str) -> bool {
#[no_mangle]
pub extern "C" fn zcashlc_get_balance(db_data: *const u8, db_data_len: usize, account: i32) -> i64 {
let res = catch_panic(|| {
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?;
let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let account = if account >= 0 {
account as u32
@ -546,10 +544,25 @@ pub extern "C" fn zcashlc_get_balance(db_data: *const u8, db_data_len: usize, ac
};
let account = AccountId(account);
match db_data.get_balance(account) {
Ok(balance) => Ok(balance.into()),
Err(e) => Err(format_err!("Error while fetching balance: {}", e)),
}
// match db_data.get_balance(account) {
// Ok(balance) => Ok(balance.into()),
// Err(e) => Err(format_err!("Error while fetching balance: {}", e)),
// }
(&db_data)
.get_target_and_anchor_heights()
.map_err(|e| format_err!("Error while fetching anchor height: {}", e))
.and_then(|opt_anchor| {
opt_anchor
.map(|(h, _)| h)
.ok_or(format_err!("height not available; scan required."))
})
.and_then(|height| {
(&db_data)
.get_balance_at(account, height)
.map_err(|e| format_err!("Error while fetching verified balance: {}", e))
})
.map(|amount| amount.into())
});
unwrap_exc_or(res, -1)
}
@ -563,7 +576,7 @@ pub extern "C" fn zcashlc_get_verified_balance(
account: i32,
) -> i64 {
let res = catch_panic(|| {
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?;
let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let account = if account >= 0 {
account as u32
} else {
@ -601,7 +614,7 @@ pub extern "C" fn zcashlc_get_received_memo_as_utf8(
id_note: i64,
) -> *mut c_char {
let res = catch_panic(|| {
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?;
let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let memo = match (&db_data).get_received_memo_as_utf8(NoteId(id_note)) {
Ok(memo) => memo.unwrap_or_default(),
@ -626,7 +639,7 @@ pub extern "C" fn zcashlc_get_sent_memo_as_utf8(
id_note: i64,
) -> *mut c_char {
let res = catch_panic(|| {
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?;
let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let memo = (&db_data)
.get_sent_memo_as_utf8(NoteId(id_note))
@ -663,7 +676,7 @@ pub extern "C" fn zcashlc_validate_combined_chain(
) -> i32 {
let res = catch_panic(|| {
let block_db = block_db(db_cache, db_cache_len)?;
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?;
let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let validate_from = (&db_data)
.get_max_height_hash()
@ -698,7 +711,7 @@ pub extern "C" fn zcashlc_rewind_to_height(
height: i32,
) -> i32 {
let res = catch_panic(|| {
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?;
let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let mut update_ops = (&db_data)
.get_update_ops()
@ -737,7 +750,8 @@ pub extern "C" fn zcashlc_scan_blocks(
) -> i32 {
let res = catch_panic(|| {
let block_db = block_db(db_cache, db_cache_len)?;
let mut db_data = wallet_db(&NETWORK, db_data, db_data_len)?.get_update_ops()?;
let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let mut db_data = db_data.get_update_ops()?;
match scan_cached_blocks(&NETWORK, &block_db, &mut db_data, None) {
Ok(()) => Ok(1),
@ -755,7 +769,8 @@ pub extern "C" fn zcashlc_decrypt_and_store_transaction(
tx_len: usize,
) -> i32 {
let res = catch_panic(|| {
let mut db_data = wallet_db(&NETWORK, db_data, db_data_len)?.get_update_ops()?;
let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let mut db_data = db_data.get_update_ops()?;
let tx_bytes = unsafe { slice::from_raw_parts(tx, tx_len) };
let tx = Transaction::read(&tx_bytes[..])?;
@ -791,7 +806,8 @@ pub extern "C" fn zcashlc_create_to_address(
) -> i64 {
let res = catch_panic(|| {
let mut db_data = wallet_db(&NETWORK, db_data, db_data_len)?.get_update_ops()?;
let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let mut db_data = db_data.get_update_ops()?;
let account = if account >= 0 {
account as u32
} else {
@ -1083,13 +1099,6 @@ fn shield_funds<P: consensus::Parameters>(
// derive the corresponding t-address
let t_addr_str = derive_transparent_address_from_secret_key(sk);
let t_addr = match RecipientAddress::decode(&NETWORK, &t_addr_str) {
Some(to) => to,
None => {
return Err(format_err!("PaymentAddress is for the wrong network"));
},
};
let extsk =
match decode_extended_spending_key(NETWORK.hrp_sapling_extended_spending_key(), &extsk)
{
@ -1251,7 +1260,7 @@ pub extern "C" fn zcashlc_shield_funds(
) -> i64 {
let res = catch_panic(|| {
let db_data = wallet_db(&NETWORK, db_data, db_data_len)?;
let db_data = wallet_db(NETWORK, db_data, db_data_len)?;
let db_cache = block_db(db_cache, db_cache_len)?;
let account = if account >= 0 {
account as u32