Merge pull request #1282 from LukasKorba/1281-Database-is-locked

[#1281] Database is locked
This commit is contained in:
Kris Nuttycombe 2023-09-20 13:39:52 -06:00 committed by GitHub
commit 09c97184ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 103 additions and 8 deletions

View File

@ -8,6 +8,22 @@
import Foundation
import SQLite
extension Connection {
public func scalarLocked<V: Value>(_ query: ScalarQuery<V?>) throws -> V.ValueType? {
globalDBLock.lock()
defer { globalDBLock.unlock() }
return try scalar(query)
}
public func scalarLocked<V: Value>(_ query: ScalarQuery<V>) throws -> V {
globalDBLock.lock()
defer { globalDBLock.unlock() }
return try scalar(query)
}
}
class TransactionSQLDAO: TransactionRepository {
enum NotesTableStructure {
static let transactionID = Expression<Int>("tx")
@ -41,7 +57,7 @@ class TransactionSQLDAO: TransactionRepository {
func countAll() async throws -> Int {
do {
return try connection().scalar(transactionsView.count)
return try connection().scalarLocked(transactionsView.count)
} catch {
throw ZcashError.transactionRepositoryCountAll(error)
}
@ -49,7 +65,7 @@ class TransactionSQLDAO: TransactionRepository {
func countUnmined() async throws -> Int {
do {
return try connection().scalar(transactionsView.filter(ZcashTransaction.Overview.Column.minedHeight == nil).count)
return try connection().scalarLocked(transactionsView.filter(ZcashTransaction.Overview.Column.minedHeight == nil).count)
} catch {
throw ZcashError.transactionRepositoryCountUnmined(error)
}
@ -160,11 +176,14 @@ class TransactionSQLDAO: TransactionRepository {
}
private func execute<Entity>(_ query: View, createEntity: (Row) throws -> Entity) throws -> [Entity] {
globalDBLock.lock()
defer { globalDBLock.unlock() }
do {
let entities = try connection()
.prepare(query)
.map(createEntity)
return entities
} catch {
if let error = error as? ZcashError {

View File

@ -109,6 +109,9 @@ class UnspentTransactionOutputSQLDAO: UnspentTransactionOutputRepository {
)
"""
do {
globalDBLock.lock()
defer { globalDBLock.unlock() }
try dbProvider.connection().run(stringStatement)
} catch {
throw ZcashError.unspentTransactionOutputDAOCreateTable(error)
@ -118,8 +121,11 @@ class UnspentTransactionOutputSQLDAO: UnspentTransactionOutputRepository {
/// - Throws: `unspentTransactionOutputDAOStore` if sqlite query fails.
func store(utxos: [UnspentTransactionOutputEntity]) async throws {
do {
globalDBLock.lock()
defer { globalDBLock.unlock() }
let db = try dbProvider.connection()
try dbProvider.connection().transaction {
try db.transaction {
for utxo in utxos.map({ $0 as? UTXO ?? $0.asUTXO() }) {
try db.run(table.insert(utxo))
}
@ -132,6 +138,9 @@ class UnspentTransactionOutputSQLDAO: UnspentTransactionOutputRepository {
/// - Throws: `unspentTransactionOutputDAOClearAll` if sqlite query fails.
func clearAll(address: String?) async throws {
do {
globalDBLock.lock()
defer { globalDBLock.unlock() }
if let tAddr = address {
try dbProvider.connection().run(table.filter(TableColumns.address == tAddr).delete())
} else {
@ -178,16 +187,16 @@ class UnspentTransactionOutputSQLDAO: UnspentTransactionOutputRepository {
/// - Throws: `unspentTransactionOutputDAOBalance` if sqlite query fails.
func balance(address: String, latestHeight: BlockHeight) async throws -> WalletBalance {
do {
let verified = try dbProvider.connection().scalar(
let verified = try dbProvider.connection().scalarLocked(
table.select(TableColumns.valueZat.sum)
.filter(TableColumns.address == address)
.filter(TableColumns.height <= latestHeight - ZcashSDK.defaultStaleTolerance)
) ?? 0
let total = try dbProvider.connection().scalar(
let total = try dbProvider.connection().scalarLocked(
table.select(TableColumns.valueZat.sum)
.filter(TableColumns.address == address)
) ?? 0
return WalletBalance(
verified: Zatoshi(Int64(verified)),
total: Zatoshi(Int64(total))

View File

@ -66,6 +66,9 @@ class AccountSQDAO: AccountRepository {
/// - `accountDAOGetAll` if sqlite query fetching account data failed.
func getAll() throws -> [AccountEntity] {
do {
globalDBLock.lock()
defer { globalDBLock.unlock() }
return try dbProvider.connection()
.prepare(table)
.map { row -> DbAccount in
@ -90,6 +93,9 @@ class AccountSQDAO: AccountRepository {
func findBy(account: Int) throws -> AccountEntity? {
let query = table.filter(TableColums.account == account).limit(1)
do {
globalDBLock.lock()
defer { globalDBLock.unlock() }
return try dbProvider.connection()
.prepare(query)
.map {
@ -119,6 +125,9 @@ class AccountSQDAO: AccountRepository {
let updatedRows: Int
do {
globalDBLock.lock()
defer { globalDBLock.unlock() }
updatedRows = try dbProvider.connection().run(table.filter(TableColums.account == acc.account).update(acc))
} catch {
throw ZcashError.accountDAOUpdate(error)

View File

@ -9,6 +9,8 @@
import Foundation
import libzcashlc
let globalDBLock = NSLock()
actor ZcashRustBackend: ZcashRustBackendWelding {
let minimumConfirmations: UInt32 = 10
let useZIP317Fees = false
@ -56,6 +58,7 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
let treeStateBytes = try treeState.serializedData(partial: false).bytes
globalDBLock.lock()
let ffiBinaryKeyPtr = zcashlc_create_account(
dbData.0,
dbData.1,
@ -66,6 +69,7 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
rUntil,
networkType.networkId
)
globalDBLock.unlock()
guard let ffiBinaryKeyPtr else {
throw ZcashError.rustCreateAccount(lastErrorMessage(fallback: "`createAccount` failed with unknown error"))
@ -84,6 +88,7 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
) async throws -> Data {
var contiguousTxIdBytes = ContiguousArray<UInt8>([UInt8](repeating: 0x0, count: 32))
globalDBLock.lock()
let success = contiguousTxIdBytes.withUnsafeMutableBufferPointer { txIdBytePtr in
usk.bytes.withUnsafeBufferPointer { uskPtr in
zcashlc_create_to_address(
@ -105,6 +110,7 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
)
}
}
globalDBLock.unlock()
guard success else {
throw ZcashError.rustCreateToAddress(lastErrorMessage(fallback: "`createToAddress` failed with unknown error"))
@ -116,6 +122,7 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
}
func decryptAndStoreTransaction(txBytes: [UInt8], minedHeight: Int32) async throws {
globalDBLock.lock()
let result = zcashlc_decrypt_and_store_transaction(
dbData.0,
dbData.1,
@ -124,6 +131,7 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
UInt32(minedHeight),
networkType.networkId
)
globalDBLock.unlock()
guard result != 0 else {
throw ZcashError.rustDecryptAndStoreTransaction(lastErrorMessage(fallback: "`decryptAndStoreTransaction` failed with unknown error"))
@ -131,7 +139,9 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
}
func getBalance(account: Int32) async throws -> Int64 {
globalDBLock.lock()
let balance = zcashlc_get_balance(dbData.0, dbData.1, account, networkType.networkId)
globalDBLock.unlock()
guard balance >= 0 else {
throw ZcashError.rustGetBalance(Int(account), lastErrorMessage(fallback: "Error getting total balance from account \(account)"))
@ -141,12 +151,14 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
}
func getCurrentAddress(account: Int32) async throws -> UnifiedAddress {
globalDBLock.lock()
let addressCStr = zcashlc_get_current_address(
dbData.0,
dbData.1,
account,
networkType.networkId
)
globalDBLock.unlock()
guard let addressCStr else {
throw ZcashError.rustGetCurrentAddress(lastErrorMessage(fallback: "`getCurrentAddress` failed with unknown error"))
@ -162,12 +174,14 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
}
func getNearestRewindHeight(height: Int32) async throws -> Int32 {
globalDBLock.lock()
let result = zcashlc_get_nearest_rewind_height(
dbData.0,
dbData.1,
height,
networkType.networkId
)
globalDBLock.unlock()
guard result > 0 else {
throw ZcashError.rustGetNearestRewindHeight(lastErrorMessage(fallback: "`getNearestRewindHeight` failed with unknown error"))
@ -177,12 +191,14 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
}
func getNextAvailableAddress(account: Int32) async throws -> UnifiedAddress {
globalDBLock.lock()
let addressCStr = zcashlc_get_next_available_address(
dbData.0,
dbData.1,
account,
networkType.networkId
)
globalDBLock.unlock()
guard let addressCStr else {
throw ZcashError.rustGetNextAvailableAddress(lastErrorMessage(fallback: "`getNextAvailableAddress` failed with unknown error"))
@ -205,9 +221,11 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
var contiguousMemoBytes = ContiguousArray<UInt8>(MemoBytes.empty().bytes)
var success = false
globalDBLock.lock()
contiguousMemoBytes.withUnsafeMutableBufferPointer { memoBytePtr in
success = zcashlc_get_memo(dbData.0, dbData.1, txId.bytes, outputIndex, memoBytePtr.baseAddress, networkType.networkId)
}
globalDBLock.unlock()
guard success else { return nil }
@ -219,12 +237,14 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
throw ZcashError.rustGetTransparentBalanceNegativeAccount(Int(account))
}
globalDBLock.lock()
let balance = zcashlc_get_total_transparent_balance_for_account(
dbData.0,
dbData.1,
networkType.networkId,
account
)
globalDBLock.unlock()
guard balance >= 0 else {
throw ZcashError.rustGetTransparentBalance(
@ -237,6 +257,7 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
}
func getVerifiedBalance(account: Int32) async throws -> Int64 {
globalDBLock.lock()
let balance = zcashlc_get_verified_balance(
dbData.0,
dbData.1,
@ -244,6 +265,7 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
networkType.networkId,
minimumConfirmations
)
globalDBLock.unlock()
guard balance >= 0 else {
throw ZcashError.rustGetVerifiedBalance(
@ -260,6 +282,7 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
throw ZcashError.rustGetVerifiedTransparentBalanceNegativeAccount(Int(account))
}
globalDBLock.lock()
let balance = zcashlc_get_verified_transparent_balance_for_account(
dbData.0,
dbData.1,
@ -267,6 +290,7 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
account,
minimumConfirmations
)
globalDBLock.unlock()
guard balance >= 0 else {
throw ZcashError.rustGetVerifiedTransparentBalance(
@ -279,7 +303,11 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
}
func initDataDb(seed: [UInt8]?) async throws -> DbInitResult {
switch zcashlc_init_data_database(dbData.0, dbData.1, seed, UInt(seed?.count ?? 0), networkType.networkId) {
globalDBLock.lock()
let initResult = zcashlc_init_data_database(dbData.0, dbData.1, seed, UInt(seed?.count ?? 0), networkType.networkId)
globalDBLock.unlock()
switch initResult {
case 0: // ok
return DbInitResult.success
case 1:
@ -290,7 +318,9 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
}
func initBlockMetadataDb() async throws {
globalDBLock.lock()
let result = zcashlc_init_block_metadata_db(fsBlockDbRoot.0, fsBlockDbRoot.1)
globalDBLock.unlock()
guard result else {
throw ZcashError.rustInitBlockMetadataDb(lastErrorMessage(fallback: "`initBlockMetadataDb` failed with unknown error"))
@ -346,7 +376,9 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
fsBlocks.initialize(to: meta)
globalDBLock.lock()
let res = zcashlc_write_block_metadata(fsBlockDbRoot.0, fsBlockDbRoot.1, fsBlocks)
globalDBLock.unlock()
guard res else {
throw ZcashError.rustWriteBlocksMetadata(lastErrorMessage(fallback: "`writeBlocksMetadata` failed with unknown error"))
@ -355,7 +387,9 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
}
func latestCachedBlockHeight() async throws -> BlockHeight {
globalDBLock.lock()
let height = zcashlc_latest_cached_block_height(fsBlockDbRoot.0, fsBlockDbRoot.1)
globalDBLock.unlock()
if height >= 0 {
return BlockHeight(height)
@ -367,12 +401,14 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
}
func listTransparentReceivers(account: Int32) async throws -> [TransparentAddress] {
globalDBLock.lock()
let encodedKeysPtr = zcashlc_list_transparent_receivers(
dbData.0,
dbData.1,
account,
networkType.networkId
)
globalDBLock.unlock()
guard let encodedKeysPtr else {
throw ZcashError.rustListTransparentReceivers(lastErrorMessage(fallback: "`listTransparentReceivers` failed with unknown error"))
@ -404,6 +440,7 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
value: Int64,
height: BlockHeight
) async throws {
globalDBLock.lock()
let result = zcashlc_put_utxo(
dbData.0,
dbData.1,
@ -416,6 +453,7 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
Int32(height),
networkType.networkId
)
globalDBLock.unlock()
guard result else {
throw ZcashError.rustPutUnspentTransparentOutput(lastErrorMessage(fallback: "`putUnspentTransparentOutput` failed with unknown error"))
@ -423,7 +461,9 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
}
func rewindToHeight(height: Int32) async throws {
globalDBLock.lock()
let result = zcashlc_rewind_to_height(dbData.0, dbData.1, height, networkType.networkId)
globalDBLock.unlock()
guard result else {
throw ZcashError.rustRewindToHeight(height, lastErrorMessage(fallback: "`rewindToHeight` failed with unknown error"))
@ -431,7 +471,9 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
}
func rewindCacheToHeight(height: Int32) async throws {
globalDBLock.lock()
let result = zcashlc_rewind_fs_block_cache_to_height(fsBlockDbRoot.0, fsBlockDbRoot.1, height)
globalDBLock.unlock()
guard result else {
throw ZcashError.rustRewindCacheToHeight(lastErrorMessage(fallback: "`rewindCacheToHeight` failed with unknown error"))
@ -486,7 +528,9 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
rootsPtr.initialize(to: roots)
globalDBLock.lock()
let res = zcashlc_put_sapling_subtree_roots(dbData.0, dbData.1, startIndex, rootsPtr, networkType.networkId)
globalDBLock.unlock()
guard res else {
throw ZcashError.rustPutSaplingSubtreeRoots(lastErrorMessage(fallback: "`putSaplingSubtreeRoots` failed with unknown error"))
@ -495,7 +539,9 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
}
func updateChainTip(height: Int32) async throws {
globalDBLock.lock()
let result = zcashlc_update_chain_tip(dbData.0, dbData.1, height, networkType.networkId)
globalDBLock.unlock()
guard result else {
throw ZcashError.rustUpdateChainTip(lastErrorMessage(fallback: "`updateChainTip` failed with unknown error"))
@ -503,7 +549,9 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
}
func fullyScannedHeight() async throws -> BlockHeight? {
globalDBLock.lock()
let height = zcashlc_fully_scanned_height(dbData.0, dbData.1, networkType.networkId)
globalDBLock.unlock()
if height >= 0 {
return BlockHeight(height)
@ -515,7 +563,9 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
}
func maxScannedHeight() async throws -> BlockHeight? {
globalDBLock.lock()
let height = zcashlc_max_scanned_height(dbData.0, dbData.1, networkType.networkId)
globalDBLock.unlock()
if height >= 0 {
return BlockHeight(height)
@ -527,7 +577,9 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
}
func getScanProgress() async throws -> ScanProgress? {
globalDBLock.lock()
let result = zcashlc_get_scan_progress(dbData.0, dbData.1, networkType.networkId)
globalDBLock.unlock()
if result.denominator == 0 {
switch result.numerator {
@ -542,7 +594,9 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
}
func suggestScanRanges() async throws -> [ScanRange] {
globalDBLock.lock()
let scanRangesPtr = zcashlc_suggest_scan_ranges(dbData.0, dbData.1, networkType.networkId)
globalDBLock.unlock()
guard let scanRangesPtr else {
throw ZcashError.rustSuggestScanRanges(lastErrorMessage(fallback: "`suggestScanRanges` failed with unknown error"))
@ -570,7 +624,9 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
}
func scanBlocks(fromHeight: Int32, limit: UInt32 = 0) async throws {
globalDBLock.lock()
let result = zcashlc_scan_blocks(fsBlockDbRoot.0, fsBlockDbRoot.1, dbData.0, dbData.1, fromHeight, limit, networkType.networkId)
globalDBLock.unlock()
guard result != 0 else {
throw ZcashError.rustScanBlocks(lastErrorMessage(fallback: "`scanBlocks` failed with unknown error"))
@ -584,6 +640,7 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
) async throws -> Data {
var contiguousTxIdBytes = ContiguousArray<UInt8>([UInt8](repeating: 0x0, count: 32))
globalDBLock.lock()
let success = contiguousTxIdBytes.withUnsafeMutableBufferPointer { txIdBytePtr in
usk.bytes.withUnsafeBufferPointer { uskBuffer in
zcashlc_shield_funds(
@ -604,6 +661,7 @@ actor ZcashRustBackend: ZcashRustBackendWelding {
)
}
}
globalDBLock.unlock()
guard success else {
throw ZcashError.rustShieldFunds(lastErrorMessage(fallback: "`shieldFunds` failed with unknown error"))