diff --git a/qa/pull-tester/rpc-tests.sh b/qa/pull-tester/rpc-tests.sh index 6a2348012..73d9ddb76 100755 --- a/qa/pull-tester/rpc-tests.sh +++ b/qa/pull-tester/rpc-tests.sh @@ -23,6 +23,7 @@ testScripts=( 'wallet_overwintertx.py' 'wallet_nullifiers.py' 'wallet_1941.py' + 'wallet_addresses.py' 'listtransactions.py' 'mempool_resurrect_test.py' 'txn_doublespend.py' diff --git a/qa/rpc-tests/wallet_addresses.py b/qa/rpc-tests/wallet_addresses.py new file mode 100755 index 000000000..0b9669972 --- /dev/null +++ b/qa/rpc-tests/wallet_addresses.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python2 +# Copyright (c) 2018 The Zcash developers +# Distributed under the MIT software license, see the accompanying +# file COPYING or http://www.opensource.org/licenses/mit-license.php. + +from test_framework.test_framework import BitcoinTestFramework +from test_framework.util import assert_equal, start_nodes + +# Test wallet address behaviour across network upgradesa\ +class WalletAddressesTest(BitcoinTestFramework): + + def setup_nodes(self): + return start_nodes(4, self.options.tmpdir, [[ + '-nuparams=5ba81b19:202', # Overwinter + '-nuparams=76b809bb:204', # Sapling + ]] * 4) + + def run_test(self): + def addr_checks(default_type): + # Check default type, as well as explicit types + types_and_addresses = [ + (default_type, self.nodes[0].z_getnewaddress()), + ('sprout', self.nodes[0].z_getnewaddress('sprout')), + ('sapling', self.nodes[0].z_getnewaddress('sapling')), + ] + + all_addresses = self.nodes[0].z_listaddresses() + + for addr_type, addr in types_and_addresses: + res = self.nodes[0].z_validateaddress(addr) + assert(res['isvalid']) + assert(res['ismine']) + assert_equal(res['type'], addr_type) + assert(addr in all_addresses) + + # Sanity-check the test harness + assert_equal(self.nodes[0].getblockcount(), 200) + + # Current height = 200 -> Sprout + # Default address type is Sprout + print "Testing height 200 (Sprout)" + addr_checks('sprout') + + self.nodes[0].generate(1) + self.sync_all() + + # Current height = 201 -> Sprout + # Default address type is Sprout + print "Testing height 201 (Sprout)" + addr_checks('sprout') + + self.nodes[0].generate(1) + self.sync_all() + + # Current height = 202 -> Overwinter + # Default address type is Sprout + print "Testing height 202 (Overwinter)" + addr_checks('sprout') + + self.nodes[0].generate(1) + self.sync_all() + + # Current height = 203 -> Overwinter + # Default address type is Sprout + print "Testing height 203 (Overwinter)" + addr_checks('sprout') + + self.nodes[0].generate(1) + self.sync_all() + + # Current height = 204 -> Sapling + # Default address type is Sprout + print "Testing height 204 (Sapling)" + addr_checks('sprout') + +if __name__ == '__main__': + WalletAddressesTest().main() diff --git a/src/bech32.cpp b/src/bech32.cpp index 2889f8f99..78c35b976 100644 --- a/src/bech32.cpp +++ b/src/bech32.cpp @@ -169,7 +169,7 @@ std::pair Decode(const std::string& str) { } if (lower && upper) return {}; size_t pos = str.rfind('1'); - if (str.size() > 90 || pos == str.npos || pos == 0 || pos + 7 > str.size()) { + if (str.size() > 1023 || pos == str.npos || pos == 0 || pos + 7 > str.size()) { return {}; } data values(str.size() - 1 - pos); diff --git a/src/keystore.h b/src/keystore.h index 76772dbf1..c856005c1 100644 --- a/src/keystore.h +++ b/src/keystore.h @@ -75,6 +75,7 @@ public: virtual bool GetSaplingIncomingViewingKey( const libzcash::SaplingPaymentAddress &addr, libzcash::SaplingIncomingViewingKey& ivkOut) const =0; + virtual void GetSaplingPaymentAddresses(std::set &setAddress) const =0; //! Support for viewing keys virtual bool AddViewingKey(const libzcash::SproutViewingKey &vk) =0; @@ -251,6 +252,19 @@ public: virtual bool GetSaplingIncomingViewingKey( const libzcash::SaplingPaymentAddress &addr, libzcash::SaplingIncomingViewingKey& ivkOut) const; + void GetSaplingPaymentAddresses(std::set &setAddress) const + { + setAddress.clear(); + { + LOCK(cs_SpendingKeyStore); + auto mi = mapSaplingIncomingViewingKeys.begin(); + while (mi != mapSaplingIncomingViewingKeys.end()) + { + setAddress.insert((*mi).first); + mi++; + } + } + } virtual bool AddViewingKey(const libzcash::SproutViewingKey &vk); virtual bool RemoveViewingKey(const libzcash::SproutViewingKey &vk); diff --git a/src/test/bech32_tests.cpp b/src/test/bech32_tests.cpp index f71ca1bf2..02252bcbf 100644 --- a/src/test/bech32_tests.cpp +++ b/src/test/bech32_tests.cpp @@ -28,6 +28,7 @@ BOOST_AUTO_TEST_CASE(bip173_testvectors_valid) "A12UEL5L", "a12uel5l", "an83characterlonghumanreadablepartthatcontainsthenumber1andtheexcludedcharactersbio1tt5tgs", + "an84characterslonghumanreadablepartthatcontainsthenumber1andtheexcludedcharactersbio1569pvx", "abcdef1qpzry9x8gf2tvdw0s3jn54khce6mua7lmqqqxw", "11qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqc8247j", "split1checkupstagehandshakeupstreamerranterredcaperred2y9e3w", @@ -48,7 +49,6 @@ BOOST_AUTO_TEST_CASE(bip173_testvectors_invalid) " 1nwldj5", "\x7f""1axkwrx", "\x80""1eym55h", - "an84characterslonghumanreadablepartthatcontainsthenumber1andtheexcludedcharactersbio1569pvx", "pzry9x0s0muk", "1pzry9x0s0muk", "x1b4n0q5v", diff --git a/src/test/rpc_wallet_tests.cpp b/src/test/rpc_wallet_tests.cpp index 5064526ba..4b478df40 100644 --- a/src/test/rpc_wallet_tests.cpp +++ b/src/test/rpc_wallet_tests.cpp @@ -561,6 +561,9 @@ BOOST_AUTO_TEST_CASE(rpc_wallet_z_importexport) std::set addrs; pwalletMain->GetPaymentAddresses(addrs); BOOST_CHECK(addrs.size()==0); + std::set saplingAddrs; + pwalletMain->GetSaplingPaymentAddresses(saplingAddrs); + BOOST_CHECK(saplingAddrs.empty()); // verify import and export key for (int i = 0; i < n1; i++) { @@ -586,7 +589,7 @@ BOOST_AUTO_TEST_CASE(rpc_wallet_z_importexport) // Verify we can list the keys imported BOOST_CHECK_NO_THROW(retValue = CallRPC("z_listaddresses")); UniValue arr = retValue.get_array(); - BOOST_CHECK(arr.size() == n1); + BOOST_CHECK(arr.size() == (2 * n1)); // Put addresses into a set std::unordered_set myaddrs; @@ -601,9 +604,10 @@ BOOST_AUTO_TEST_CASE(rpc_wallet_z_importexport) // Verify number of addresses stored in wallet is n1+n2 int numAddrs = myaddrs.size(); - BOOST_CHECK(numAddrs == n1+n2); + BOOST_CHECK(numAddrs == (2 * n1) + n2); pwalletMain->GetPaymentAddresses(addrs); - BOOST_CHECK(addrs.size()==numAddrs); + pwalletMain->GetSaplingPaymentAddresses(saplingAddrs); + BOOST_CHECK(addrs.size() + saplingAddrs.size() == numAddrs); // Ask wallet to list addresses BOOST_CHECK_NO_THROW(retValue = CallRPC("z_listaddresses")); diff --git a/src/wallet/gtest/test_wallet_zkeys.cpp b/src/wallet/gtest/test_wallet_zkeys.cpp index 9fe107ccc..22b8a905b 100644 --- a/src/wallet/gtest/test_wallet_zkeys.cpp +++ b/src/wallet/gtest/test_wallet_zkeys.cpp @@ -16,7 +16,15 @@ TEST(wallet_zkeys_tests, store_and_load_sapling_zkeys) { CWallet wallet; + // wallet should be empty + std::set addrs; + wallet.GetSaplingPaymentAddresses(addrs); + ASSERT_EQ(0, addrs.size()); + + // wallet should have one key auto address = wallet.GenerateNewSaplingZKey(); + wallet.GetSaplingPaymentAddresses(addrs); + ASSERT_EQ(1, addrs.size()); // verify wallet has incoming viewing key for the address ASSERT_TRUE(wallet.HaveSaplingIncomingViewingKey(address)); @@ -28,6 +36,17 @@ TEST(wallet_zkeys_tests, store_and_load_sapling_zkeys) { // verify wallet did add it auto fvk = sk.full_viewing_key(); ASSERT_TRUE(wallet.HaveSaplingSpendingKey(fvk)); + + // verify spending key stored correctly + libzcash::SaplingSpendingKey keyOut; + wallet.GetSaplingSpendingKey(fvk, keyOut); + ASSERT_EQ(sk, keyOut); + + // verify there are two keys + wallet.GetSaplingPaymentAddresses(addrs); + ASSERT_EQ(2, addrs.size()); + ASSERT_EQ(1, addrs.count(address)); + ASSERT_EQ(1, addrs.count(sk.default_address())); } /** diff --git a/src/wallet/rpcwallet.cpp b/src/wallet/rpcwallet.cpp index 696e2b288..d196da4e5 100644 --- a/src/wallet/rpcwallet.cpp +++ b/src/wallet/rpcwallet.cpp @@ -42,6 +42,9 @@ using namespace std; using namespace libzcash; +const std::string ADDR_TYPE_SPROUT = "sprout"; +const std::string ADDR_TYPE_SAPLING = "sapling"; + extern UniValue TxJoinSplitToJSON(const CTransaction& tx); int64_t nWalletUnlockTime; @@ -3100,15 +3103,21 @@ UniValue z_getnewaddress(const UniValue& params, bool fHelp) if (!EnsureWalletIsAvailable(fHelp)) return NullUniValue; - if (fHelp || params.size() > 0) + std::string defaultType = ADDR_TYPE_SPROUT; + + if (fHelp || params.size() > 1) throw runtime_error( - "z_getnewaddress\n" - "\nReturns a new zaddr for receiving payments.\n" + "z_getnewaddress ( type )\n" + "\nReturns a new shielded address for receiving payments.\n" + "\nWith no arguments, returns a Sprout address.\n" "\nArguments:\n" + "1. \"type\" (string, optional, default=\"" + defaultType + "\") The type of address. One of [\"" + + ADDR_TYPE_SPROUT + "\", \"" + ADDR_TYPE_SAPLING + "\"].\n" "\nResult:\n" - "\"zcashaddress\" (string) The new zaddr\n" + "\"zcashaddress\" (string) The new shielded address.\n" "\nExamples:\n" + HelpExampleCli("z_getnewaddress", "") + + HelpExampleCli("z_getnewaddress", ADDR_TYPE_SAPLING) + HelpExampleRpc("z_getnewaddress", "") ); @@ -3116,8 +3125,18 @@ UniValue z_getnewaddress(const UniValue& params, bool fHelp) EnsureWalletIsUnlocked(); - auto zaddr = pwalletMain->GenerateNewZKey(); - return EncodePaymentAddress(zaddr); + auto addrType = defaultType; + if (params.size() > 0) { + addrType = params[0].get_str(); + } + + if (addrType == ADDR_TYPE_SPROUT) { + return EncodePaymentAddress(pwalletMain->GenerateNewZKey()); + } else if (addrType == ADDR_TYPE_SAPLING) { + return EncodePaymentAddress(pwalletMain->GenerateNewSaplingZKey()); + } else { + throw JSONRPCError(RPC_INVALID_PARAMETER, "Invalid address type"); + } } @@ -3129,7 +3148,7 @@ UniValue z_listaddresses(const UniValue& params, bool fHelp) if (fHelp || params.size() > 1) throw runtime_error( "z_listaddresses ( includeWatchonly )\n" - "\nReturns the list of zaddr belonging to the wallet.\n" + "\nReturns the list of Sprout and Sapling shielded addresses belonging to the wallet.\n" "\nArguments:\n" "1. includeWatchonly (bool, optional, default=false) Also include watchonly addresses (see 'z_importviewingkey')\n" "\nResult:\n" @@ -3150,12 +3169,28 @@ UniValue z_listaddresses(const UniValue& params, bool fHelp) } UniValue ret(UniValue::VARR); - // TODO: Add Sapling support - std::set addresses; - pwalletMain->GetPaymentAddresses(addresses); - for (auto addr : addresses ) { - if (fIncludeWatchonly || pwalletMain->HaveSpendingKey(addr)) { - ret.push_back(EncodePaymentAddress(addr)); + { + std::set addresses; + pwalletMain->GetPaymentAddresses(addresses); + for (auto addr : addresses) { + if (fIncludeWatchonly || pwalletMain->HaveSpendingKey(addr)) { + ret.push_back(EncodePaymentAddress(addr)); + } + } + } + { + std::set addresses; + pwalletMain->GetSaplingPaymentAddresses(addresses); + libzcash::SaplingIncomingViewingKey ivk; + libzcash::SaplingFullViewingKey fvk; + for (auto addr : addresses) { + if (fIncludeWatchonly || ( + pwalletMain->GetSaplingIncomingViewingKey(addr, ivk) && + pwalletMain->GetSaplingFullViewingKey(ivk, fvk) && + pwalletMain->HaveSaplingSpendingKey(fvk) + )) { + ret.push_back(EncodePaymentAddress(addr)); + } } } return ret;