diff --git a/qa/rpc-tests/wallet.py b/qa/rpc-tests/wallet.py index c4c14b2eb..83d4cf4c5 100755 --- a/qa/rpc-tests/wallet.py +++ b/qa/rpc-tests/wallet.py @@ -335,8 +335,6 @@ class WalletTest (BitcoinTestFramework): mytxid = wait_and_assert_operationid_status(self.nodes[2], self.nodes[2].z_sendmany(mytaddr, recipients)) - self.sync_all() - self.nodes[2].generate(1) self.sync_all() # check balances @@ -344,10 +342,18 @@ class WalletTest (BitcoinTestFramework): zsendmanyfee = Decimal('0.0001') node2utxobalance = Decimal('23.998') - zsendmanynotevalue - zsendmanyfee + # check shielded balance status with getwalletinfo + wallet_info = self.nodes[2].getwalletinfo() + assert_equal(Decimal(wallet_info["shielded_unconfirmed_balance"]), zsendmanynotevalue) + assert_equal(Decimal(wallet_info["shielded_balance"]), Decimal('0.0')) + + self.nodes[2].generate(1) + self.sync_all() + assert_equal(self.nodes[2].getbalance(), node2utxobalance) assert_equal(self.nodes[2].getbalance("*"), node2utxobalance) - # check zaddr balance + # check zaddr balance with z_getbalance assert_equal(self.nodes[2].z_getbalance(myzaddr), zsendmanynotevalue) # check via z_gettotalbalance @@ -356,6 +362,11 @@ class WalletTest (BitcoinTestFramework): assert_equal(Decimal(resp["private"]), zsendmanynotevalue) assert_equal(Decimal(resp["total"]), node2utxobalance + zsendmanynotevalue) + # check confirmed shielded balance with getwalletinfo + wallet_info = self.nodes[2].getwalletinfo() + assert_equal(Decimal(wallet_info["shielded_unconfirmed_balance"]), Decimal('0.0')) + assert_equal(Decimal(wallet_info["shielded_balance"]), zsendmanynotevalue) + # there should be at least one joinsplit mytxdetails = self.nodes[2].gettransaction(mytxid) myvjoinsplits = mytxdetails["vjoinsplit"] diff --git a/src/wallet/rpcwallet.cpp b/src/wallet/rpcwallet.cpp index 29b9cbcd4..8274d5c37 100644 --- a/src/wallet/rpcwallet.cpp +++ b/src/wallet/rpcwallet.cpp @@ -2338,6 +2338,8 @@ UniValue settxfee(const UniValue& params, bool fHelp) return true; } +CAmount getBalanceZaddr(std::string address, int minDepth = 1, int maxDepth = INT_MAX, bool ignoreUnspendable=true); + UniValue getwalletinfo(const UniValue& params, bool fHelp) { if (!EnsureWalletIsAvailable(fHelp)) @@ -2353,6 +2355,8 @@ UniValue getwalletinfo(const UniValue& params, bool fHelp) " \"balance\": xxxxxxx, (numeric) the total confirmed transparent balance of the wallet in " + CURRENCY_UNIT + "\n" " \"unconfirmed_balance\": xxx, (numeric) the total unconfirmed transparent balance of the wallet in " + CURRENCY_UNIT + "\n" " \"immature_balance\": xxxxxx, (numeric) the total immature transparent balance of the wallet in " + CURRENCY_UNIT + "\n" + " \"shielded_balance\": xxxxxxx, (numeric) the total confirmed shielded balance of the wallet in " + CURRENCY_UNIT + "\n" + " \"shielded_unconfirmed_balance\": xxx, (numeric) the total unconfirmed shielded balance of the wallet in " + CURRENCY_UNIT + "\n" " \"txcount\": xxxxxxx, (numeric) the total number of transactions in the wallet\n" " \"keypoololdest\": xxxxxx, (numeric) the timestamp (seconds since GMT epoch) of the oldest pre-generated key in the key pool\n" " \"keypoolsize\": xxxx, (numeric) how many new keys are pre-generated\n" @@ -2372,6 +2376,8 @@ UniValue getwalletinfo(const UniValue& params, bool fHelp) obj.pushKV("balance", ValueFromAmount(pwalletMain->GetBalance())); obj.pushKV("unconfirmed_balance", ValueFromAmount(pwalletMain->GetUnconfirmedBalance())); obj.pushKV("immature_balance", ValueFromAmount(pwalletMain->GetImmatureBalance())); + obj.pushKV("shielded_balance", FormatMoney(getBalanceZaddr("", 1, INT_MAX))); + obj.pushKV("shielded_unconfirmed_balance", FormatMoney(getBalanceZaddr("", 0, 0))); obj.pushKV("txcount", (int)pwalletMain->mapWallet.size()); obj.pushKV("keypoololdest", pwalletMain->GetOldestKeyPoolTime()); obj.pushKV("keypoolsize", (int)pwalletMain->GetKeyPoolSize()); @@ -2636,11 +2642,11 @@ UniValue z_listunspent(const UniValue& params, bool fHelp) // User did not provide zaddrs, so use default i.e. all addresses std::set sproutzaddrs = {}; pwalletMain->GetSproutPaymentAddresses(sproutzaddrs); - + // Sapling support std::set saplingzaddrs = {}; pwalletMain->GetSaplingPaymentAddresses(saplingzaddrs); - + zaddrs.insert(sproutzaddrs.begin(), sproutzaddrs.end()); zaddrs.insert(saplingzaddrs.begin(), saplingzaddrs.end()); } @@ -2652,7 +2658,7 @@ UniValue z_listunspent(const UniValue& params, bool fHelp) std::vector saplingEntries; pwalletMain->GetFilteredNotes(sproutEntries, saplingEntries, zaddrs, nMinDepth, nMaxDepth, true, !fIncludeWatchonly, false); std::set> nullifierSet = pwalletMain->GetNullifiersForAddresses(zaddrs); - + for (auto & entry : sproutEntries) { UniValue obj(UniValue::VOBJ); obj.pushKV("txid", entry.jsop.hash.ToString()); @@ -3380,12 +3386,19 @@ CAmount getBalanceTaddr(std::string transparentAddress, int minDepth=1, bool ign return balance; } -CAmount getBalanceZaddr(std::string address, int minDepth = 1, bool ignoreUnspendable=true) { +CAmount getBalanceZaddr(std::string address, int minDepth, int maxDepth, bool ignoreUnspendable) { CAmount balance = 0; std::vector sproutEntries; std::vector saplingEntries; LOCK2(cs_main, pwalletMain->cs_wallet); - pwalletMain->GetFilteredNotes(sproutEntries, saplingEntries, address, minDepth, true, ignoreUnspendable); + + std::set filterAddresses; + if (address.length() > 0) { + KeyIO keyIO(Params()); + filterAddresses.insert(keyIO.DecodePaymentAddress(address)); + } + + pwalletMain->GetFilteredNotes(sproutEntries, saplingEntries, filterAddresses, minDepth, maxDepth, true, ignoreUnspendable); for (auto & entry : sproutEntries) { balance += CAmount(entry.note.value()); } @@ -3582,7 +3595,7 @@ UniValue z_getbalance(const UniValue& params, bool fHelp) if (fromTaddr) { nBalance = getBalanceTaddr(fromaddress, nMinDepth, false); } else { - nBalance = getBalanceZaddr(fromaddress, nMinDepth, false); + nBalance = getBalanceZaddr(fromaddress, nMinDepth, INT_MAX, false); } // inZat @@ -3644,7 +3657,7 @@ UniValue z_gettotalbalance(const UniValue& params, bool fHelp) // pwalletMain->GetBalance() does not accept min depth parameter // so we use our own method to get balance of utxos. CAmount nBalance = getBalanceTaddr("", nMinDepth, !fIncludeWatchonly); - CAmount nPrivateBalance = getBalanceZaddr("", nMinDepth, !fIncludeWatchonly); + CAmount nPrivateBalance = getBalanceZaddr("", nMinDepth, INT_MAX, !fIncludeWatchonly); CAmount nTotalBalance = nBalance + nPrivateBalance; UniValue result(UniValue::VOBJ); result.pushKV("transparent", FormatMoney(nBalance)); @@ -4282,7 +4295,7 @@ UniValue z_sendmany(const UniValue& params, bool fHelp) CMutableTransaction contextualTx = CreateNewContextualCMutableTransaction(Params().GetConsensus(), nextBlockHeight); bool isShielded = !fromTaddr || zaddrRecipients.size() > 0; if (contextualTx.nVersion == 1 && isShielded) { - contextualTx.nVersion = 2; // Tx format should support vJoinSplits + contextualTx.nVersion = 2; // Tx format should support vJoinSplits } // Create operation and add to global queue @@ -4628,7 +4641,7 @@ UniValue z_shieldcoinbase(const UniValue& params, bool fHelp) CMutableTransaction contextualTx = CreateNewContextualCMutableTransaction( Params().GetConsensus(), nextBlockHeight); if (contextualTx.nVersion == 1) { - contextualTx.nVersion = 2; // Tx format should support vJoinSplit + contextualTx.nVersion = 2; // Tx format should support vJoinSplit } // Create operation and add to global queue