Auto merge of #3932 - LarryRuane:3823-deadlock-simple, r=daira

simplify locking, merge cs_SpendingKeyStore into cs_KeyStore

Closes #3823.
This commit is contained in:
Homu 2019-11-05 14:10:18 -08:00
commit 3bc8c07563
5 changed files with 161 additions and 202 deletions

View File

@ -25,7 +25,7 @@ bool CKeyStore::AddKey(const CKey &key) {
bool CBasicKeyStore::SetHDSeed(const HDSeed& seed) bool CBasicKeyStore::SetHDSeed(const HDSeed& seed)
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
if (!hdSeed.IsNull()) { if (!hdSeed.IsNull()) {
// Don't allow an existing seed to be changed. We can maybe relax this // Don't allow an existing seed to be changed. We can maybe relax this
// restriction later once we have worked out the UX implications. // restriction later once we have worked out the UX implications.
@ -37,13 +37,13 @@ bool CBasicKeyStore::SetHDSeed(const HDSeed& seed)
bool CBasicKeyStore::HaveHDSeed() const bool CBasicKeyStore::HaveHDSeed() const
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
return !hdSeed.IsNull(); return !hdSeed.IsNull();
} }
bool CBasicKeyStore::GetHDSeed(HDSeed& seedOut) const bool CBasicKeyStore::GetHDSeed(HDSeed& seedOut) const
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
if (hdSeed.IsNull()) { if (hdSeed.IsNull()) {
return false; return false;
} else { } else {
@ -115,7 +115,7 @@ bool CBasicKeyStore::HaveWatchOnly() const
bool CBasicKeyStore::AddSproutSpendingKey(const libzcash::SproutSpendingKey &sk) bool CBasicKeyStore::AddSproutSpendingKey(const libzcash::SproutSpendingKey &sk)
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
auto address = sk.address(); auto address = sk.address();
mapSproutSpendingKeys[address] = sk; mapSproutSpendingKeys[address] = sk;
mapNoteDecryptors.insert(std::make_pair(address, ZCNoteDecryption(sk.receiving_key()))); mapNoteDecryptors.insert(std::make_pair(address, ZCNoteDecryption(sk.receiving_key())));
@ -127,7 +127,7 @@ bool CBasicKeyStore::AddSaplingSpendingKey(
const libzcash::SaplingExtendedSpendingKey &sk, const libzcash::SaplingExtendedSpendingKey &sk,
const libzcash::SaplingPaymentAddress &defaultAddr) const libzcash::SaplingPaymentAddress &defaultAddr)
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
auto fvk = sk.expsk.full_viewing_key(); auto fvk = sk.expsk.full_viewing_key();
// if SaplingFullViewingKey is not in SaplingFullViewingKeyMap, add it // if SaplingFullViewingKey is not in SaplingFullViewingKeyMap, add it
@ -142,7 +142,7 @@ bool CBasicKeyStore::AddSaplingSpendingKey(
bool CBasicKeyStore::AddSproutViewingKey(const libzcash::SproutViewingKey &vk) bool CBasicKeyStore::AddSproutViewingKey(const libzcash::SproutViewingKey &vk)
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
auto address = vk.address(); auto address = vk.address();
mapSproutViewingKeys[address] = vk; mapSproutViewingKeys[address] = vk;
mapNoteDecryptors.insert(std::make_pair(address, ZCNoteDecryption(vk.sk_enc))); mapNoteDecryptors.insert(std::make_pair(address, ZCNoteDecryption(vk.sk_enc)));
@ -153,7 +153,7 @@ bool CBasicKeyStore::AddSaplingFullViewingKey(
const libzcash::SaplingFullViewingKey &fvk, const libzcash::SaplingFullViewingKey &fvk,
const libzcash::SaplingPaymentAddress &defaultAddr) const libzcash::SaplingPaymentAddress &defaultAddr)
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
auto ivk = fvk.in_viewing_key(); auto ivk = fvk.in_viewing_key();
mapSaplingFullViewingKeys[ivk] = fvk; mapSaplingFullViewingKeys[ivk] = fvk;
@ -167,7 +167,7 @@ bool CBasicKeyStore::AddSaplingIncomingViewingKey(
const libzcash::SaplingIncomingViewingKey &ivk, const libzcash::SaplingIncomingViewingKey &ivk,
const libzcash::SaplingPaymentAddress &addr) const libzcash::SaplingPaymentAddress &addr)
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
// Add addr -> SaplingIncomingViewing to SaplingIncomingViewingKeyMap // Add addr -> SaplingIncomingViewing to SaplingIncomingViewingKeyMap
mapSaplingIncomingViewingKeys[addr] = ivk; mapSaplingIncomingViewingKeys[addr] = ivk;
@ -177,26 +177,26 @@ bool CBasicKeyStore::AddSaplingIncomingViewingKey(
bool CBasicKeyStore::RemoveSproutViewingKey(const libzcash::SproutViewingKey &vk) bool CBasicKeyStore::RemoveSproutViewingKey(const libzcash::SproutViewingKey &vk)
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
mapSproutViewingKeys.erase(vk.address()); mapSproutViewingKeys.erase(vk.address());
return true; return true;
} }
bool CBasicKeyStore::HaveSproutViewingKey(const libzcash::SproutPaymentAddress &address) const bool CBasicKeyStore::HaveSproutViewingKey(const libzcash::SproutPaymentAddress &address) const
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
return mapSproutViewingKeys.count(address) > 0; return mapSproutViewingKeys.count(address) > 0;
} }
bool CBasicKeyStore::HaveSaplingFullViewingKey(const libzcash::SaplingIncomingViewingKey &ivk) const bool CBasicKeyStore::HaveSaplingFullViewingKey(const libzcash::SaplingIncomingViewingKey &ivk) const
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
return mapSaplingFullViewingKeys.count(ivk) > 0; return mapSaplingFullViewingKeys.count(ivk) > 0;
} }
bool CBasicKeyStore::HaveSaplingIncomingViewingKey(const libzcash::SaplingPaymentAddress &addr) const bool CBasicKeyStore::HaveSaplingIncomingViewingKey(const libzcash::SaplingPaymentAddress &addr) const
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
return mapSaplingIncomingViewingKeys.count(addr) > 0; return mapSaplingIncomingViewingKeys.count(addr) > 0;
} }
@ -204,7 +204,7 @@ bool CBasicKeyStore::GetSproutViewingKey(
const libzcash::SproutPaymentAddress &address, const libzcash::SproutPaymentAddress &address,
libzcash::SproutViewingKey &vkOut) const libzcash::SproutViewingKey &vkOut) const
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
SproutViewingKeyMap::const_iterator mi = mapSproutViewingKeys.find(address); SproutViewingKeyMap::const_iterator mi = mapSproutViewingKeys.find(address);
if (mi != mapSproutViewingKeys.end()) { if (mi != mapSproutViewingKeys.end()) {
vkOut = mi->second; vkOut = mi->second;
@ -216,7 +216,7 @@ bool CBasicKeyStore::GetSproutViewingKey(
bool CBasicKeyStore::GetSaplingFullViewingKey(const libzcash::SaplingIncomingViewingKey &ivk, bool CBasicKeyStore::GetSaplingFullViewingKey(const libzcash::SaplingIncomingViewingKey &ivk,
libzcash::SaplingFullViewingKey &fvkOut) const libzcash::SaplingFullViewingKey &fvkOut) const
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
SaplingFullViewingKeyMap::const_iterator mi = mapSaplingFullViewingKeys.find(ivk); SaplingFullViewingKeyMap::const_iterator mi = mapSaplingFullViewingKeys.find(ivk);
if (mi != mapSaplingFullViewingKeys.end()) { if (mi != mapSaplingFullViewingKeys.end()) {
fvkOut = mi->second; fvkOut = mi->second;
@ -228,7 +228,7 @@ bool CBasicKeyStore::GetSaplingFullViewingKey(const libzcash::SaplingIncomingVie
bool CBasicKeyStore::GetSaplingIncomingViewingKey(const libzcash::SaplingPaymentAddress &addr, bool CBasicKeyStore::GetSaplingIncomingViewingKey(const libzcash::SaplingPaymentAddress &addr,
libzcash::SaplingIncomingViewingKey &ivkOut) const libzcash::SaplingIncomingViewingKey &ivkOut) const
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
SaplingIncomingViewingKeyMap::const_iterator mi = mapSaplingIncomingViewingKeys.find(addr); SaplingIncomingViewingKeyMap::const_iterator mi = mapSaplingIncomingViewingKeys.find(addr);
if (mi != mapSaplingIncomingViewingKeys.end()) { if (mi != mapSaplingIncomingViewingKeys.end()) {
ivkOut = mi->second; ivkOut = mi->second;
@ -242,6 +242,7 @@ bool CBasicKeyStore::GetSaplingExtendedSpendingKey(const libzcash::SaplingPaymen
libzcash::SaplingIncomingViewingKey ivk; libzcash::SaplingIncomingViewingKey ivk;
libzcash::SaplingFullViewingKey fvk; libzcash::SaplingFullViewingKey fvk;
LOCK(cs_KeyStore);
return GetSaplingIncomingViewingKey(addr, ivk) && return GetSaplingIncomingViewingKey(addr, ivk) &&
GetSaplingFullViewingKey(ivk, fvk) && GetSaplingFullViewingKey(ivk, fvk) &&
GetSaplingSpendingKey(fvk, extskOut); GetSaplingSpendingKey(fvk, extskOut);

View File

@ -23,7 +23,6 @@ class CKeyStore
{ {
protected: protected:
mutable CCriticalSection cs_KeyStore; mutable CCriticalSection cs_KeyStore;
mutable CCriticalSection cs_SpendingKeyStore;
public: public:
virtual ~CKeyStore() {} virtual ~CKeyStore() {}
@ -185,7 +184,7 @@ public:
{ {
bool result; bool result;
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
result = (mapSproutSpendingKeys.count(address) > 0); result = (mapSproutSpendingKeys.count(address) > 0);
} }
return result; return result;
@ -193,7 +192,7 @@ public:
bool GetSproutSpendingKey(const libzcash::SproutPaymentAddress &address, libzcash::SproutSpendingKey &skOut) const bool GetSproutSpendingKey(const libzcash::SproutPaymentAddress &address, libzcash::SproutSpendingKey &skOut) const
{ {
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
SproutSpendingKeyMap::const_iterator mi = mapSproutSpendingKeys.find(address); SproutSpendingKeyMap::const_iterator mi = mapSproutSpendingKeys.find(address);
if (mi != mapSproutSpendingKeys.end()) if (mi != mapSproutSpendingKeys.end())
{ {
@ -206,7 +205,7 @@ public:
bool GetNoteDecryptor(const libzcash::SproutPaymentAddress &address, ZCNoteDecryption &decOut) const bool GetNoteDecryptor(const libzcash::SproutPaymentAddress &address, ZCNoteDecryption &decOut) const
{ {
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
NoteDecryptorMap::const_iterator mi = mapNoteDecryptors.find(address); NoteDecryptorMap::const_iterator mi = mapNoteDecryptors.find(address);
if (mi != mapNoteDecryptors.end()) if (mi != mapNoteDecryptors.end())
{ {
@ -220,7 +219,7 @@ public:
{ {
setAddress.clear(); setAddress.clear();
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
SproutSpendingKeyMap::const_iterator mi = mapSproutSpendingKeys.begin(); SproutSpendingKeyMap::const_iterator mi = mapSproutSpendingKeys.begin();
while (mi != mapSproutSpendingKeys.end()) while (mi != mapSproutSpendingKeys.end())
{ {
@ -244,7 +243,7 @@ public:
{ {
bool result; bool result;
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
result = (mapSaplingSpendingKeys.count(fvk) > 0); result = (mapSaplingSpendingKeys.count(fvk) > 0);
} }
return result; return result;
@ -252,7 +251,7 @@ public:
bool GetSaplingSpendingKey(const libzcash::SaplingFullViewingKey &fvk, libzcash::SaplingExtendedSpendingKey &skOut) const bool GetSaplingSpendingKey(const libzcash::SaplingFullViewingKey &fvk, libzcash::SaplingExtendedSpendingKey &skOut) const
{ {
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
SaplingSpendingKeyMap::const_iterator mi = mapSaplingSpendingKeys.find(fvk); SaplingSpendingKeyMap::const_iterator mi = mapSaplingSpendingKeys.find(fvk);
if (mi != mapSaplingSpendingKeys.end()) if (mi != mapSaplingSpendingKeys.end())
@ -288,7 +287,7 @@ public:
{ {
setAddress.clear(); setAddress.clear();
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
auto mi = mapSaplingIncomingViewingKeys.begin(); auto mi = mapSaplingIncomingViewingKeys.begin();
while (mi != mapSaplingIncomingViewingKeys.end()) while (mi != mapSaplingIncomingViewingKeys.end())
{ {

View File

@ -185,9 +185,9 @@ static bool DecryptSaplingSpendingKey(const CKeyingMaterial& vMasterKey,
return sk.expsk.full_viewing_key() == extfvk.fvk; return sk.expsk.full_viewing_key() == extfvk.fvk;
} }
// cs_KeyStore lock must be held by caller
bool CCryptoKeyStore::SetCrypted() bool CCryptoKeyStore::SetCrypted()
{ {
LOCK2(cs_KeyStore, cs_SpendingKeyStore);
if (fUseCrypto) if (fUseCrypto)
return true; return true;
if (!(mapKeys.empty() && mapSproutSpendingKeys.empty() && mapSaplingSpendingKeys.empty())) if (!(mapKeys.empty() && mapSproutSpendingKeys.empty() && mapSaplingSpendingKeys.empty()))
@ -198,11 +198,10 @@ bool CCryptoKeyStore::SetCrypted()
bool CCryptoKeyStore::Lock() bool CCryptoKeyStore::Lock()
{ {
if (!SetCrypted())
return false;
{ {
LOCK(cs_KeyStore); LOCK(cs_KeyStore);
if (!SetCrypted())
return false;
vMasterKey.clear(); vMasterKey.clear();
} }
@ -213,7 +212,7 @@ bool CCryptoKeyStore::Lock()
bool CCryptoKeyStore::Unlock(const CKeyingMaterial& vMasterKeyIn) bool CCryptoKeyStore::Unlock(const CKeyingMaterial& vMasterKeyIn)
{ {
{ {
LOCK2(cs_KeyStore, cs_SpendingKeyStore); LOCK(cs_KeyStore);
if (!SetCrypted()) if (!SetCrypted())
return false; return false;
@ -290,8 +289,8 @@ bool CCryptoKeyStore::Unlock(const CKeyingMaterial& vMasterKeyIn)
bool CCryptoKeyStore::SetHDSeed(const HDSeed& seed) bool CCryptoKeyStore::SetHDSeed(const HDSeed& seed)
{ {
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
if (!IsCrypted()) { if (!fUseCrypto) {
return CBasicKeyStore::SetHDSeed(seed); return CBasicKeyStore::SetHDSeed(seed);
} }
@ -316,27 +315,25 @@ bool CCryptoKeyStore::SetCryptedHDSeed(
const uint256& seedFp, const uint256& seedFp,
const std::vector<unsigned char>& vchCryptedSecret) const std::vector<unsigned char>& vchCryptedSecret)
{ {
{ LOCK(cs_KeyStore);
LOCK(cs_SpendingKeyStore); if (!fUseCrypto) {
if (!IsCrypted()) { return false;
return false;
}
if (!cryptedHDSeed.first.IsNull()) {
// Don't allow an existing seed to be changed. We can maybe relax this
// restriction later once we have worked out the UX implications.
return false;
}
cryptedHDSeed = std::make_pair(seedFp, vchCryptedSecret);
} }
if (!cryptedHDSeed.first.IsNull()) {
// Don't allow an existing seed to be changed. We can maybe relax this
// restriction later once we have worked out the UX implications.
return false;
}
cryptedHDSeed = std::make_pair(seedFp, vchCryptedSecret);
return true; return true;
} }
bool CCryptoKeyStore::HaveHDSeed() const bool CCryptoKeyStore::HaveHDSeed() const
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
if (!IsCrypted()) if (!fUseCrypto)
return CBasicKeyStore::HaveHDSeed(); return CBasicKeyStore::HaveHDSeed();
return !cryptedHDSeed.second.empty(); return !cryptedHDSeed.second.empty();
@ -344,8 +341,8 @@ bool CCryptoKeyStore::HaveHDSeed() const
bool CCryptoKeyStore::GetHDSeed(HDSeed& seedOut) const bool CCryptoKeyStore::GetHDSeed(HDSeed& seedOut) const
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
if (!IsCrypted()) if (!fUseCrypto)
return CBasicKeyStore::GetHDSeed(seedOut); return CBasicKeyStore::GetHDSeed(seedOut);
if (cryptedHDSeed.second.empty()) if (cryptedHDSeed.second.empty())
@ -356,125 +353,106 @@ bool CCryptoKeyStore::GetHDSeed(HDSeed& seedOut) const
bool CCryptoKeyStore::AddKeyPubKey(const CKey& key, const CPubKey &pubkey) bool CCryptoKeyStore::AddKeyPubKey(const CKey& key, const CPubKey &pubkey)
{ {
{ LOCK(cs_KeyStore);
LOCK(cs_KeyStore); if (!fUseCrypto)
if (!IsCrypted()) return CBasicKeyStore::AddKeyPubKey(key, pubkey);
return CBasicKeyStore::AddKeyPubKey(key, pubkey);
if (IsLocked()) if (IsLocked())
return false; return false;
std::vector<unsigned char> vchCryptedSecret; std::vector<unsigned char> vchCryptedSecret;
CKeyingMaterial vchSecret(key.begin(), key.end()); CKeyingMaterial vchSecret(key.begin(), key.end());
if (!EncryptSecret(vMasterKey, vchSecret, pubkey.GetHash(), vchCryptedSecret)) if (!EncryptSecret(vMasterKey, vchSecret, pubkey.GetHash(), vchCryptedSecret))
return false; return false;
if (!AddCryptedKey(pubkey, vchCryptedSecret)) return AddCryptedKey(pubkey, vchCryptedSecret);
return false;
}
return true;
} }
bool CCryptoKeyStore::AddCryptedKey(const CPubKey &vchPubKey, const std::vector<unsigned char> &vchCryptedSecret) bool CCryptoKeyStore::AddCryptedKey(const CPubKey &vchPubKey, const std::vector<unsigned char> &vchCryptedSecret)
{ {
{ LOCK(cs_KeyStore);
LOCK(cs_KeyStore); if (!SetCrypted())
if (!SetCrypted()) return false;
return false;
mapCryptedKeys[vchPubKey.GetID()] = make_pair(vchPubKey, vchCryptedSecret); mapCryptedKeys[vchPubKey.GetID()] = make_pair(vchPubKey, vchCryptedSecret);
}
return true; return true;
} }
bool CCryptoKeyStore::GetKey(const CKeyID &address, CKey& keyOut) const bool CCryptoKeyStore::GetKey(const CKeyID &address, CKey& keyOut) const
{ {
{ LOCK(cs_KeyStore);
LOCK(cs_KeyStore); if (!fUseCrypto)
if (!IsCrypted()) return CBasicKeyStore::GetKey(address, keyOut);
return CBasicKeyStore::GetKey(address, keyOut);
CryptedKeyMap::const_iterator mi = mapCryptedKeys.find(address); CryptedKeyMap::const_iterator mi = mapCryptedKeys.find(address);
if (mi != mapCryptedKeys.end()) if (mi != mapCryptedKeys.end())
{ {
const CPubKey &vchPubKey = (*mi).second.first; const CPubKey &vchPubKey = (*mi).second.first;
const std::vector<unsigned char> &vchCryptedSecret = (*mi).second.second; const std::vector<unsigned char> &vchCryptedSecret = (*mi).second.second;
return DecryptKey(vMasterKey, vchCryptedSecret, vchPubKey, keyOut); return DecryptKey(vMasterKey, vchCryptedSecret, vchPubKey, keyOut);
}
} }
return false; return false;
} }
bool CCryptoKeyStore::GetPubKey(const CKeyID &address, CPubKey& vchPubKeyOut) const bool CCryptoKeyStore::GetPubKey(const CKeyID &address, CPubKey& vchPubKeyOut) const
{ {
{ LOCK(cs_KeyStore);
LOCK(cs_KeyStore); if (!fUseCrypto)
if (!IsCrypted()) return CKeyStore::GetPubKey(address, vchPubKeyOut);
return CKeyStore::GetPubKey(address, vchPubKeyOut);
CryptedKeyMap::const_iterator mi = mapCryptedKeys.find(address); CryptedKeyMap::const_iterator mi = mapCryptedKeys.find(address);
if (mi != mapCryptedKeys.end()) if (mi != mapCryptedKeys.end())
{ {
vchPubKeyOut = (*mi).second.first; vchPubKeyOut = (*mi).second.first;
return true; return true;
}
} }
return false; return false;
} }
bool CCryptoKeyStore::AddSproutSpendingKey(const libzcash::SproutSpendingKey &sk) bool CCryptoKeyStore::AddSproutSpendingKey(const libzcash::SproutSpendingKey &sk)
{ {
{ LOCK(cs_KeyStore);
LOCK(cs_SpendingKeyStore); if (!fUseCrypto)
if (!IsCrypted()) return CBasicKeyStore::AddSproutSpendingKey(sk);
return CBasicKeyStore::AddSproutSpendingKey(sk);
if (IsLocked()) if (IsLocked())
return false; return false;
std::vector<unsigned char> vchCryptedSecret; std::vector<unsigned char> vchCryptedSecret;
CSecureDataStream ss(SER_NETWORK, PROTOCOL_VERSION); CSecureDataStream ss(SER_NETWORK, PROTOCOL_VERSION);
ss << sk; ss << sk;
CKeyingMaterial vchSecret(ss.begin(), ss.end()); CKeyingMaterial vchSecret(ss.begin(), ss.end());
auto address = sk.address(); auto address = sk.address();
if (!EncryptSecret(vMasterKey, vchSecret, address.GetHash(), vchCryptedSecret)) if (!EncryptSecret(vMasterKey, vchSecret, address.GetHash(), vchCryptedSecret))
return false; return false;
if (!AddCryptedSproutSpendingKey(address, sk.receiving_key(), vchCryptedSecret)) return AddCryptedSproutSpendingKey(address, sk.receiving_key(), vchCryptedSecret);
return false;
}
return true;
} }
bool CCryptoKeyStore::AddSaplingSpendingKey( bool CCryptoKeyStore::AddSaplingSpendingKey(
const libzcash::SaplingExtendedSpendingKey &sk, const libzcash::SaplingExtendedSpendingKey &sk,
const libzcash::SaplingPaymentAddress &defaultAddr) const libzcash::SaplingPaymentAddress &defaultAddr)
{ {
{ LOCK(cs_KeyStore);
LOCK(cs_SpendingKeyStore); if (!fUseCrypto) {
if (!IsCrypted()) { return CBasicKeyStore::AddSaplingSpendingKey(sk, defaultAddr);
return CBasicKeyStore::AddSaplingSpendingKey(sk, defaultAddr);
}
if (IsLocked()) {
return false;
}
std::vector<unsigned char> vchCryptedSecret;
CSecureDataStream ss(SER_NETWORK, PROTOCOL_VERSION);
ss << sk;
CKeyingMaterial vchSecret(ss.begin(), ss.end());
auto extfvk = sk.ToXFVK();
if (!EncryptSecret(vMasterKey, vchSecret, extfvk.fvk.GetFingerprint(), vchCryptedSecret)) {
return false;
}
if (!AddCryptedSaplingSpendingKey(extfvk, vchCryptedSecret, defaultAddr)) {
return false;
}
} }
return true;
if (IsLocked()) {
return false;
}
std::vector<unsigned char> vchCryptedSecret;
CSecureDataStream ss(SER_NETWORK, PROTOCOL_VERSION);
ss << sk;
CKeyingMaterial vchSecret(ss.begin(), ss.end());
auto extfvk = sk.ToXFVK();
if (!EncryptSecret(vMasterKey, vchSecret, extfvk.fvk.GetFingerprint(), vchCryptedSecret)) {
return false;
}
return AddCryptedSaplingSpendingKey(extfvk, vchCryptedSecret, defaultAddr);
} }
bool CCryptoKeyStore::AddCryptedSproutSpendingKey( bool CCryptoKeyStore::AddCryptedSproutSpendingKey(
@ -482,14 +460,12 @@ bool CCryptoKeyStore::AddCryptedSproutSpendingKey(
const libzcash::ReceivingKey &rk, const libzcash::ReceivingKey &rk,
const std::vector<unsigned char> &vchCryptedSecret) const std::vector<unsigned char> &vchCryptedSecret)
{ {
{ LOCK(cs_KeyStore);
LOCK(cs_SpendingKeyStore); if (!SetCrypted())
if (!SetCrypted()) return false;
return false;
mapCryptedSproutSpendingKeys[address] = vchCryptedSecret; mapCryptedSproutSpendingKeys[address] = vchCryptedSecret;
mapNoteDecryptors.insert(std::make_pair(address, ZCNoteDecryption(rk))); mapNoteDecryptors.insert(std::make_pair(address, ZCNoteDecryption(rk)));
}
return true; return true;
} }
@ -498,51 +474,45 @@ bool CCryptoKeyStore::AddCryptedSaplingSpendingKey(
const std::vector<unsigned char> &vchCryptedSecret, const std::vector<unsigned char> &vchCryptedSecret,
const libzcash::SaplingPaymentAddress &defaultAddr) const libzcash::SaplingPaymentAddress &defaultAddr)
{ {
{ LOCK(cs_KeyStore);
LOCK(cs_SpendingKeyStore); if (!SetCrypted()) {
if (!SetCrypted()) { return false;
return false;
}
// if SaplingFullViewingKey is not in SaplingFullViewingKeyMap, add it
if (!AddSaplingFullViewingKey(extfvk.fvk, defaultAddr)) {
return false;
}
mapCryptedSaplingSpendingKeys[extfvk] = vchCryptedSecret;
} }
// if SaplingFullViewingKey is not in SaplingFullViewingKeyMap, add it
if (!AddSaplingFullViewingKey(extfvk.fvk, defaultAddr)) {
return false;
}
mapCryptedSaplingSpendingKeys[extfvk] = vchCryptedSecret;
return true; return true;
} }
bool CCryptoKeyStore::GetSproutSpendingKey(const libzcash::SproutPaymentAddress &address, libzcash::SproutSpendingKey &skOut) const bool CCryptoKeyStore::GetSproutSpendingKey(const libzcash::SproutPaymentAddress &address, libzcash::SproutSpendingKey &skOut) const
{ {
{ LOCK(cs_KeyStore);
LOCK(cs_SpendingKeyStore); if (!fUseCrypto)
if (!IsCrypted()) return CBasicKeyStore::GetSproutSpendingKey(address, skOut);
return CBasicKeyStore::GetSproutSpendingKey(address, skOut);
CryptedSproutSpendingKeyMap::const_iterator mi = mapCryptedSproutSpendingKeys.find(address); CryptedSproutSpendingKeyMap::const_iterator mi = mapCryptedSproutSpendingKeys.find(address);
if (mi != mapCryptedSproutSpendingKeys.end()) if (mi != mapCryptedSproutSpendingKeys.end())
{ {
const std::vector<unsigned char> &vchCryptedSecret = (*mi).second; const std::vector<unsigned char> &vchCryptedSecret = (*mi).second;
return DecryptSproutSpendingKey(vMasterKey, vchCryptedSecret, address, skOut); return DecryptSproutSpendingKey(vMasterKey, vchCryptedSecret, address, skOut);
}
} }
return false; return false;
} }
bool CCryptoKeyStore::GetSaplingSpendingKey(const libzcash::SaplingFullViewingKey &fvk, libzcash::SaplingExtendedSpendingKey &skOut) const bool CCryptoKeyStore::GetSaplingSpendingKey(const libzcash::SaplingFullViewingKey &fvk, libzcash::SaplingExtendedSpendingKey &skOut) const
{ {
{ LOCK(cs_KeyStore);
LOCK(cs_SpendingKeyStore); if (!fUseCrypto)
if (!IsCrypted()) return CBasicKeyStore::GetSaplingSpendingKey(fvk, skOut);
return CBasicKeyStore::GetSaplingSpendingKey(fvk, skOut);
for (auto entry : mapCryptedSaplingSpendingKeys) { for (auto entry : mapCryptedSaplingSpendingKeys) {
if (entry.first.fvk == fvk) { if (entry.first.fvk == fvk) {
const std::vector<unsigned char> &vchCryptedSecret = entry.second; const std::vector<unsigned char> &vchCryptedSecret = entry.second;
return DecryptSaplingSpendingKey(vMasterKey, vchCryptedSecret, entry.first, skOut); return DecryptSaplingSpendingKey(vMasterKey, vchCryptedSecret, entry.first, skOut);
}
} }
} }
return false; return false;
@ -551,8 +521,8 @@ bool CCryptoKeyStore::GetSaplingSpendingKey(const libzcash::SaplingFullViewingKe
bool CCryptoKeyStore::EncryptKeys(CKeyingMaterial& vMasterKeyIn) bool CCryptoKeyStore::EncryptKeys(CKeyingMaterial& vMasterKeyIn)
{ {
{ {
LOCK2(cs_KeyStore, cs_SpendingKeyStore); LOCK(cs_KeyStore);
if (!mapCryptedKeys.empty() || IsCrypted()) if (!mapCryptedKeys.empty() || fUseCrypto)
return false; return false;
fUseCrypto = true; fUseCrypto = true;

View File

@ -157,19 +157,14 @@ public:
bool IsCrypted() const bool IsCrypted() const
{ {
LOCK(cs_KeyStore);
return fUseCrypto; return fUseCrypto;
} }
bool IsLocked() const bool IsLocked() const
{ {
if (!IsCrypted()) LOCK(cs_KeyStore);
return false; return fUseCrypto && vMasterKey.empty();
bool result;
{
LOCK(cs_KeyStore);
result = vMasterKey.empty();
}
return result;
} }
bool Lock(); bool Lock();
@ -183,19 +178,17 @@ public:
bool AddKeyPubKey(const CKey& key, const CPubKey &pubkey); bool AddKeyPubKey(const CKey& key, const CPubKey &pubkey);
bool HaveKey(const CKeyID &address) const bool HaveKey(const CKeyID &address) const
{ {
{ LOCK(cs_KeyStore);
LOCK(cs_KeyStore); if (!fUseCrypto)
if (!IsCrypted()) return CBasicKeyStore::HaveKey(address);
return CBasicKeyStore::HaveKey(address); return mapCryptedKeys.count(address) > 0;
return mapCryptedKeys.count(address) > 0;
}
return false;
} }
bool GetKey(const CKeyID &address, CKey& keyOut) const; bool GetKey(const CKeyID &address, CKey& keyOut) const;
bool GetPubKey(const CKeyID &address, CPubKey& vchPubKeyOut) const; bool GetPubKey(const CKeyID &address, CPubKey& vchPubKeyOut) const;
void GetKeys(std::set<CKeyID> &setAddress) const void GetKeys(std::set<CKeyID> &setAddress) const
{ {
if (!IsCrypted()) LOCK(cs_KeyStore);
if (!fUseCrypto)
{ {
CBasicKeyStore::GetKeys(setAddress); CBasicKeyStore::GetKeys(setAddress);
return; return;
@ -215,18 +208,16 @@ public:
bool AddSproutSpendingKey(const libzcash::SproutSpendingKey &sk); bool AddSproutSpendingKey(const libzcash::SproutSpendingKey &sk);
bool HaveSproutSpendingKey(const libzcash::SproutPaymentAddress &address) const bool HaveSproutSpendingKey(const libzcash::SproutPaymentAddress &address) const
{ {
{ LOCK(cs_KeyStore);
LOCK(cs_SpendingKeyStore); if (!fUseCrypto)
if (!IsCrypted()) return CBasicKeyStore::HaveSproutSpendingKey(address);
return CBasicKeyStore::HaveSproutSpendingKey(address); return mapCryptedSproutSpendingKeys.count(address) > 0;
return mapCryptedSproutSpendingKeys.count(address) > 0;
}
return false;
} }
bool GetSproutSpendingKey(const libzcash::SproutPaymentAddress &address, libzcash::SproutSpendingKey &skOut) const; bool GetSproutSpendingKey(const libzcash::SproutPaymentAddress &address, libzcash::SproutSpendingKey &skOut) const;
void GetSproutPaymentAddresses(std::set<libzcash::SproutPaymentAddress> &setAddress) const void GetSproutPaymentAddresses(std::set<libzcash::SproutPaymentAddress> &setAddress) const
{ {
if (!IsCrypted()) LOCK(cs_KeyStore);
if (!fUseCrypto)
{ {
CBasicKeyStore::GetSproutPaymentAddresses(setAddress); CBasicKeyStore::GetSproutPaymentAddresses(setAddress);
return; return;
@ -249,14 +240,12 @@ public:
const libzcash::SaplingPaymentAddress &defaultAddr); const libzcash::SaplingPaymentAddress &defaultAddr);
bool HaveSaplingSpendingKey(const libzcash::SaplingFullViewingKey &fvk) const bool HaveSaplingSpendingKey(const libzcash::SaplingFullViewingKey &fvk) const
{ {
{ LOCK(cs_KeyStore);
LOCK(cs_SpendingKeyStore); if (!fUseCrypto)
if (!IsCrypted()) return CBasicKeyStore::HaveSaplingSpendingKey(fvk);
return CBasicKeyStore::HaveSaplingSpendingKey(fvk); for (auto entry : mapCryptedSaplingSpendingKeys) {
for (auto entry : mapCryptedSaplingSpendingKeys) { if (entry.first.fvk == fvk) {
if (entry.first.fvk == fvk) { return true;
return true;
}
} }
} }
return false; return false;

View File

@ -1802,7 +1802,7 @@ boost::optional<uint256> CWallet::GetSproutNoteNullifier(const JSDescription &js
*/ */
mapSproutNoteData_t CWallet::FindMySproutNotes(const CTransaction &tx) const mapSproutNoteData_t CWallet::FindMySproutNotes(const CTransaction &tx) const
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
uint256 hash = tx.GetHash(); uint256 hash = tx.GetHash();
mapSproutNoteData_t noteData; mapSproutNoteData_t noteData;
@ -1850,7 +1850,7 @@ mapSproutNoteData_t CWallet::FindMySproutNotes(const CTransaction &tx) const
*/ */
std::pair<mapSaplingNoteData_t, SaplingIncomingViewingKeyMap> CWallet::FindMySaplingNotes(const CTransaction &tx) const std::pair<mapSaplingNoteData_t, SaplingIncomingViewingKeyMap> CWallet::FindMySaplingNotes(const CTransaction &tx) const
{ {
LOCK(cs_SpendingKeyStore); LOCK(cs_KeyStore);
uint256 hash = tx.GetHash(); uint256 hash = tx.GetHash();
mapSaplingNoteData_t noteData; mapSaplingNoteData_t noteData;