diff --git a/.gitignore b/.gitignore index 3cf17d73..6ce73d07 100644 --- a/.gitignore +++ b/.gitignore @@ -76,7 +76,6 @@ Pods # do not commit generated libraries to this repo lib *.a -*.generated.swift env-vars.sh .vscode/ diff --git a/.swiftlint.yml b/.swiftlint.yml index e77b991a..f3f0a22a 100644 --- a/.swiftlint.yml +++ b/.swiftlint.yml @@ -15,6 +15,7 @@ excluded: - ZcashLightClientKitTests/Constants.generated.swift - build/ - docs/ + - Tests/TestUtils/Sourcery/GeneratedMocks/AutoMockable.generated.swift disabled_rules: - notification_center_detachment diff --git a/.swiftlint_tests.yml b/.swiftlint_tests.yml index 83611d47..aaa1c68f 100644 --- a/.swiftlint_tests.yml +++ b/.swiftlint_tests.yml @@ -11,6 +11,7 @@ excluded: - ZcashLightClientKitTests/Constants.generated.swift - build/ - docs/ + - Tests/TestUtils/Sourcery/GeneratedMocks/AutoMockable.generated.swift disabled_rules: - notification_center_detachment diff --git a/CHANGELOG.md b/CHANGELOG.md index 49fd4dd0..8500ad73 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,12 @@ # unreleased +### [#888] Updates to layer between Swift and Rust + +This is mostly internal change. But it also touches the public API. + +`KeyDeriving` protocol is changed. And therefore `DerivationTool` is changed. `deriveUnifiedSpendingKey(seed:accountIndex:)` and +`deriveUnifiedFullViewingKey(from:)` methods are now async. `DerivationTool` offers alternatives for these methods. Alternatives are using either +closures or Combine. + ### [#469] ZcashRustBackendWelding to Async This is mostly internal change. But it also touches the public API. diff --git a/Example/ZcashLightClientSample/ZcashLightClientSample.xcodeproj/xcshareddata/xcschemes/All.xcscheme b/Example/ZcashLightClientSample/ZcashLightClientSample.xcodeproj/xcshareddata/xcschemes/All.xcscheme index 763fac8f..a8723e0a 100644 --- a/Example/ZcashLightClientSample/ZcashLightClientSample.xcodeproj/xcshareddata/xcschemes/All.xcscheme +++ b/Example/ZcashLightClientSample/ZcashLightClientSample.xcodeproj/xcshareddata/xcschemes/All.xcscheme @@ -104,6 +104,20 @@ ReferencedContainer = "container:../.."> + + + + 0 else { - let error = rustBackend.lastError() ?? RustWeldingError.genericError( - message: "unknown error getting nearest rewind height for height: \(height)" - ) + let nearestHeight: Int32 + do { + nearestHeight = try await rustBackend.getNearestRewindHeight(height: height) + } catch { await fail(error) return await context.completion(.failure(error)) } // FIXME: [#719] this should be done on the rust layer, https://github.com/zcash/ZcashLightClientKit/issues/719 let rewindHeight = max(Int32(nearestHeight - 1), Int32(config.walletBirthday)) - guard await rustBackend.rewindToHeight(dbData: config.dataDb, height: rewindHeight, networkType: self.config.network.networkType) else { - let error = rustBackend.lastError() ?? RustWeldingError.genericError(message: "unknown error rewinding to height \(height)") + + do { + try await rustBackend.rewindToHeight(height: rewindHeight) + } catch { await fail(error) return await context.completion(.failure(error)) } @@ -715,7 +694,7 @@ actor CompactBlockProcessor { func validateServer() async { do { let info = try await self.service.getInfo() - try Self.validateServerInfo( + try await Self.validateServerInfo( info, saplingActivation: self.config.saplingActivation, localNetwork: self.config.network, @@ -1055,14 +1034,10 @@ actor CompactBlockProcessor { self.consecutiveChainValidationErrors += 1 - let rewindResult = await rustBackend.rewindToHeight( - dbData: config.dataDb, - height: Int32(rewindHeight), - networkType: self.config.network.networkType - ) - - guard rewindResult else { - await fail(rustBackend.lastError() ?? RustWeldingError.genericError(message: "unknown error rewinding to height \(height)")) + do { + try await rustBackend.rewindToHeight(height: Int32(rewindHeight)) + } catch { + await fail(error) return } @@ -1205,11 +1180,7 @@ extension CompactBlockProcessor.State: Equatable { extension CompactBlockProcessor { func getUnifiedAddress(accountIndex: Int) async throws -> UnifiedAddress { - try await rustBackend.getCurrentAddress( - dbData: config.dataDb, - account: Int32(accountIndex), - networkType: config.network.networkType - ) + try await rustBackend.getCurrentAddress(account: Int32(accountIndex)) } func getSaplingAddress(accountIndex: Int) async throws -> SaplingAddress { @@ -1227,18 +1198,10 @@ extension CompactBlockProcessor { return WalletBalance( verified: Zatoshi( - try await rustBackend.getVerifiedTransparentBalance( - dbData: config.dataDb, - account: Int32(accountIndex), - networkType: config.network.networkType - ) + try await rustBackend.getVerifiedTransparentBalance(account: Int32(accountIndex)) ), total: Zatoshi( - try await rustBackend.getTransparentBalance( - dbData: config.dataDb, - account: Int32(accountIndex), - networkType: config.network.networkType - ) + try await rustBackend.getTransparentBalance(account: Int32(accountIndex)) ) ) } @@ -1269,19 +1232,15 @@ extension CompactBlockProcessor { var skipped: [UnspentTransactionOutputEntity] = [] for utxo in utxos { do { - if try await rustBackend.putUnspentTransparentOutput( - dbData: dataDb, + try await rustBackend.putUnspentTransparentOutput( txid: utxo.txid.bytes, index: utxo.index, script: utxo.script.bytes, value: Int64(utxo.valueZat), - height: utxo.height, - networkType: self.config.network.networkType - ) { - refreshed.append(utxo) - } else { - skipped.append(utxo) - } + height: utxo.height + ) + + refreshed.append(utxo) } catch { logger.info("failed to put utxo - error: \(error)") skipped.append(utxo) @@ -1389,7 +1348,7 @@ extension CompactBlockProcessor { downloaderService: BlockDownloaderService, transactionRepository: TransactionRepository, config: Configuration, - rustBackend: ZcashRustBackendWelding.Type, + rustBackend: ZcashRustBackendWelding, internalSyncProgress: InternalSyncProgress ) async throws -> CompactBlockProcessor.NextState { // It should be ok to not create new Task here because this method is already async. But for some reason something not good happens @@ -1397,7 +1356,7 @@ extension CompactBlockProcessor { let task = Task(priority: .userInitiated) { let info = try await service.getInfo() - try CompactBlockProcessor.validateServerInfo( + try await CompactBlockProcessor.validateServerInfo( info, saplingActivation: config.saplingActivation, localNetwork: config.network, diff --git a/Sources/ZcashLightClientKit/Block/Enhance/BlockEnhancer.swift b/Sources/ZcashLightClientKit/Block/Enhance/BlockEnhancer.swift index ad209a75..c5081534 100644 --- a/Sources/ZcashLightClientKit/Block/Enhance/BlockEnhancer.swift +++ b/Sources/ZcashLightClientKit/Block/Enhance/BlockEnhancer.swift @@ -14,20 +14,14 @@ enum BlockEnhancerError: Error { case txIdNotFound(txId: Data) } -struct BlockEnhancerConfig { - let dataDb: URL - let networkType: NetworkType -} - protocol BlockEnhancer { func enhance(at range: CompactBlockRange, didEnhance: (EnhancementProgress) async -> Void) async throws -> [ZcashTransaction.Overview] } struct BlockEnhancerImpl { let blockDownloaderService: BlockDownloaderService - let config: BlockEnhancerConfig let internalSyncProgress: InternalSyncProgress - let rustBackend: ZcashRustBackendWelding.Type + let rustBackend: ZcashRustBackendWelding let transactionRepository: TransactionRepository let metrics: SDKMetrics let logger: Logger @@ -41,17 +35,13 @@ struct BlockEnhancerImpl { let block = String(describing: transaction.minedHeight) logger.debug("Decrypting and storing transaction id: \(transactionID) block: \(block)") - let decryptionResult = await rustBackend.decryptAndStoreTransaction( - dbData: config.dataDb, - txBytes: fetchedTransaction.raw.bytes, - minedHeight: Int32(fetchedTransaction.minedHeight), - networkType: config.networkType - ) - - guard decryptionResult else { - throw BlockEnhancerError.decryptError( - error: rustBackend.lastError() ?? .genericError(message: "`decryptAndStoreTransaction` failed. No message available") + do { + try await rustBackend.decryptAndStoreTransaction( + txBytes: fetchedTransaction.raw.bytes, + minedHeight: Int32(fetchedTransaction.minedHeight) ) + } catch { + throw BlockEnhancerError.decryptError(error: error) } let confirmedTx: ZcashTransaction.Overview diff --git a/Sources/ZcashLightClientKit/Block/FetchUnspentTxOutputs/UTXOFetcher.swift b/Sources/ZcashLightClientKit/Block/FetchUnspentTxOutputs/UTXOFetcher.swift index 80cc3096..dd3f9b1d 100644 --- a/Sources/ZcashLightClientKit/Block/FetchUnspentTxOutputs/UTXOFetcher.swift +++ b/Sources/ZcashLightClientKit/Block/FetchUnspentTxOutputs/UTXOFetcher.swift @@ -13,8 +13,6 @@ enum UTXOFetcherError: Error { } struct UTXOFetcherConfig { - let dataDb: URL - let networkType: NetworkType let walletBirthdayProvider: () async -> BlockHeight } @@ -27,7 +25,7 @@ struct UTXOFetcherImpl { let blockDownloaderService: BlockDownloaderService let config: UTXOFetcherConfig let internalSyncProgress: InternalSyncProgress - let rustBackend: ZcashRustBackendWelding.Type + let rustBackend: ZcashRustBackendWelding let metrics: SDKMetrics let logger: Logger } @@ -41,11 +39,7 @@ extension UTXOFetcherImpl: UTXOFetcher { var tAddresses: [TransparentAddress] = [] for account in accounts { - tAddresses += try await rustBackend.listTransparentReceivers( - dbData: config.dataDb, - account: Int32(account), - networkType: config.networkType - ) + tAddresses += try await rustBackend.listTransparentReceivers(account: Int32(account)) } var utxos: [UnspentTransactionOutputEntity] = [] @@ -64,19 +58,15 @@ extension UTXOFetcherImpl: UTXOFetcher { let startTime = Date() for utxo in utxos { do { - if try await rustBackend.putUnspentTransparentOutput( - dbData: config.dataDb, + try await rustBackend.putUnspentTransparentOutput( txid: utxo.txid.bytes, index: utxo.index, script: utxo.script.bytes, value: Int64(utxo.valueZat), - height: utxo.height, - networkType: config.networkType - ) { - refreshed.append(utxo) - } else { - skipped.append(utxo) - } + height: utxo.height + ) + + refreshed.append(utxo) await internalSyncProgress.set(utxo.height, .latestUTXOFetchedHeight) } catch { diff --git a/Sources/ZcashLightClientKit/Block/FilesystemStorage/FSCompactBlockRepository.swift b/Sources/ZcashLightClientKit/Block/FilesystemStorage/FSCompactBlockRepository.swift index 3192242c..fe9432fe 100644 --- a/Sources/ZcashLightClientKit/Block/FilesystemStorage/FSCompactBlockRepository.swift +++ b/Sources/ZcashLightClientKit/Block/FilesystemStorage/FSCompactBlockRepository.swift @@ -52,7 +52,10 @@ extension FSCompactBlockRepository: CompactBlockRepository { try fileManager.createDirectory(at: blocksDirectory, withIntermediateDirectories: true) } - guard try await self.metadataStore.initFsBlockDbRoot(self.fsBlockDbRoot) else { + do { + try await self.metadataStore.initFsBlockDbRoot() + } catch { + logger.error("Blocks metadata store init failed with error: \(error)") throw CompactBlockRepositoryError.failedToInitializeCache } } @@ -210,12 +213,12 @@ extension FSBlockFileWriter { struct FSMetadataStore { var saveBlocksMeta: ([ZcashCompactBlock]) async throws -> Void var rewindToHeight: (BlockHeight) async throws -> Void - var initFsBlockDbRoot: (URL) async throws -> Bool + var initFsBlockDbRoot: () async throws -> Void var latestHeight: () async -> BlockHeight } extension FSMetadataStore { - static func live(fsBlockDbRoot: URL, rustBackend: ZcashRustBackendWelding.Type, logger: Logger) -> FSMetadataStore { + static func live(fsBlockDbRoot: URL, rustBackend: ZcashRustBackendWelding, logger: Logger) -> FSMetadataStore { FSMetadataStore { blocks in try await FSMetadataStore.saveBlocksMeta( blocks, @@ -224,13 +227,15 @@ extension FSMetadataStore { logger: logger ) } rewindToHeight: { height in - guard await rustBackend.rewindCacheToHeight(fsBlockDbRoot: fsBlockDbRoot, height: Int32(height)) else { + do { + try await rustBackend.rewindCacheToHeight(height: Int32(height)) + } catch { throw CompactBlockRepositoryError.failedToRewind(height) } - } initFsBlockDbRoot: { dbRootURL in - try await rustBackend.initBlockMetadataDb(fsBlockDbRoot: dbRootURL) + } initFsBlockDbRoot: { + try await rustBackend.initBlockMetadataDb() } latestHeight: { - await rustBackend.latestCachedBlockHeight(fsBlockDbRoot: fsBlockDbRoot) + await rustBackend.latestCachedBlockHeight() } } } @@ -244,15 +249,13 @@ extension FSMetadataStore { static func saveBlocksMeta( _ blocks: [ZcashCompactBlock], fsBlockDbRoot: URL, - rustBackend: ZcashRustBackendWelding.Type, + rustBackend: ZcashRustBackendWelding, logger: Logger ) async throws { guard !blocks.isEmpty else { return } do { - guard try await rustBackend.writeBlocksMetadata(fsBlockDbRoot: fsBlockDbRoot, blocks: blocks) else { - throw CompactBlockRepositoryError.failedToWriteMetadata - } + try await rustBackend.writeBlocksMetadata(blocks: blocks) } catch { logger.error("Failed to write metadata with error: \(error)") throw CompactBlockRepositoryError.failedToWriteMetadata diff --git a/Sources/ZcashLightClientKit/Block/SaplingParameters/SaplingParametersHandler.swift b/Sources/ZcashLightClientKit/Block/SaplingParameters/SaplingParametersHandler.swift index 08bb8d85..09746b6d 100644 --- a/Sources/ZcashLightClientKit/Block/SaplingParameters/SaplingParametersHandler.swift +++ b/Sources/ZcashLightClientKit/Block/SaplingParameters/SaplingParametersHandler.swift @@ -8,8 +8,6 @@ import Foundation struct SaplingParametersHandlerConfig { - let dataDb: URL - let networkType: NetworkType let outputParamsURL: URL let spendParamsURL: URL let saplingParamsSourceURL: SaplingParamsSourceURL @@ -21,7 +19,7 @@ protocol SaplingParametersHandler { struct SaplingParametersHandlerImpl { let config: SaplingParametersHandlerConfig - let rustBackend: ZcashRustBackendWelding.Type + let rustBackend: ZcashRustBackendWelding let logger: Logger } @@ -30,16 +28,8 @@ extension SaplingParametersHandlerImpl: SaplingParametersHandler { try Task.checkCancellation() do { - let totalShieldedBalance = try await rustBackend.getBalance( - dbData: config.dataDb, - account: Int32(0), - networkType: config.networkType - ) - let totalTransparentBalance = try await rustBackend.getTransparentBalance( - dbData: config.dataDb, - account: Int32(0), - networkType: config.networkType - ) + let totalShieldedBalance = try await rustBackend.getBalance(account: Int32(0)) + let totalTransparentBalance = try await rustBackend.getTransparentBalance(account: Int32(0)) // Download Sapling parameters only if sapling funds are detected. guard totalShieldedBalance > 0 || totalTransparentBalance > 0 else { return } diff --git a/Sources/ZcashLightClientKit/Block/Scan/BlockScanner.swift b/Sources/ZcashLightClientKit/Block/Scan/BlockScanner.swift index f4f9ff34..a49ca4c2 100644 --- a/Sources/ZcashLightClientKit/Block/Scan/BlockScanner.swift +++ b/Sources/ZcashLightClientKit/Block/Scan/BlockScanner.swift @@ -8,8 +8,6 @@ import Foundation struct BlockScannerConfig { - let fsBlockCacheRoot: URL - let dataDB: URL let networkType: NetworkType let scanningBatchSize: Int } @@ -20,7 +18,7 @@ protocol BlockScanner { struct BlockScannerImpl { let config: BlockScannerConfig - let rustBackend: ZcashRustBackendWelding.Type + let rustBackend: ZcashRustBackendWelding let transactionRepository: TransactionRepository let metrics: SDKMetrics let logger: Logger @@ -42,19 +40,16 @@ extension BlockScannerImpl: BlockScanner { let previousScannedHeight = lastScannedHeight // TODO: [#576] remove this arbitrary batch size https://github.com/zcash/ZcashLightClientKit/issues/576 - let batchSize = scanBatchSize(startScanHeight: previousScannedHeight + 1, network: self.config.networkType) + let batchSize = scanBatchSize(startScanHeight: previousScannedHeight + 1, network: config.networkType) let scanStartTime = Date() - guard await self.rustBackend.scanBlocks( - fsBlockDbRoot: config.fsBlockCacheRoot, - dbData: config.dataDB, - limit: batchSize, - networkType: config.networkType - ) else { - let error: Error = rustBackend.lastError() ?? CompactBlockProcessorError.unknown + do { + try await self.rustBackend.scanBlocks(limit: batchSize) + } catch { logger.debug("block scanning failed with error: \(String(describing: error))") throw error } + let scanFinishTime = Date() lastScannedHeight = try await transactionRepository.lastScannedHeight() diff --git a/Sources/ZcashLightClientKit/Block/Validate/BlockValidator.swift b/Sources/ZcashLightClientKit/Block/Validate/BlockValidator.swift index 9a78ff15..1936f621 100644 --- a/Sources/ZcashLightClientKit/Block/Validate/BlockValidator.swift +++ b/Sources/ZcashLightClientKit/Block/Validate/BlockValidator.swift @@ -14,20 +14,13 @@ enum BlockValidatorError: Error { case failedWithUnknownError } -struct BlockValidatorConfig { - let fsBlockCacheRoot: URL - let dataDB: URL - let networkType: NetworkType -} - protocol BlockValidator { /// Validate all the downloaded blocks that haven't been yet validated. func validate() async throws } struct BlockValidatorImpl { - let config: BlockValidatorConfig - let rustBackend: ZcashRustBackendWelding.Type + let rustBackend: ZcashRustBackendWelding let metrics: SDKMetrics let logger: Logger } @@ -37,14 +30,24 @@ extension BlockValidatorImpl: BlockValidator { try Task.checkCancellation() let startTime = Date() - let result = await rustBackend.validateCombinedChain( - fsBlockDbRoot: config.fsBlockCacheRoot, - dbData: config.dataDB, - networkType: config.networkType, - limit: 0 - ) - let finishTime = Date() + do { + try await rustBackend.validateCombinedChain(limit: 0) + pushProgressReport(startTime: startTime, finishTime: Date()) + logger.debug("validateChainFinished") + } catch { + pushProgressReport(startTime: startTime, finishTime: Date()) + switch error { + case let RustWeldingError.invalidChain(upperBound): + throw BlockValidatorError.validationFailed(height: BlockHeight(upperBound)) + + default: + throw BlockValidatorError.failedWithError(error) + } + } + } + + private func pushProgressReport(startTime: Date, finishTime: Date) { metrics.pushProgressReport( progress: BlockProgress(startHeight: 0, targetHeight: 0, progressHeight: 0), start: startTime, @@ -52,24 +55,5 @@ extension BlockValidatorImpl: BlockValidator { batchSize: 0, operation: .validateBlocks ) - - switch result { - case 0: - let rustError = rustBackend.lastError() - logger.debug("Block validation failed with error: \(String(describing: rustError))") - if let rustError { - throw BlockValidatorError.failedWithError(rustError) - } else { - throw BlockValidatorError.failedWithUnknownError - } - - case ZcashRustBackendWeldingConstants.validChain: - logger.debug("validateChainFinished") - return - - default: - logger.debug("Block validation failed at height: \(result)") - throw BlockValidatorError.validationFailed(height: BlockHeight(result)) - } } } diff --git a/Sources/ZcashLightClientKit/CombineSynchronizer.swift b/Sources/ZcashLightClientKit/CombineSynchronizer.swift index 5d35ceaa..c25b9399 100644 --- a/Sources/ZcashLightClientKit/CombineSynchronizer.swift +++ b/Sources/ZcashLightClientKit/CombineSynchronizer.swift @@ -8,13 +8,6 @@ import Combine import Foundation -/* These aliases are here to just make the API easier to read. */ - -// Publisher which emitts completed or error. No value is emitted. -public typealias CompletablePublisher = AnyPublisher -// Publisher that either emits one value and then finishes or it emits error. -public typealias SinglePublisher = AnyPublisher - /// This defines a Combine-based API for the SDK. It's expected that the implementation of this protocol is only a very thin layer that translates /// async API defined in `Synchronizer` to Combine-based API. And it doesn't do anything else. It's here so each client can choose the API that suits /// its case the best. diff --git a/Sources/ZcashLightClientKit/DAO/PendingTransactionDao.swift b/Sources/ZcashLightClientKit/DAO/PendingTransactionDao.swift index 17520ca1..a51ce521 100644 --- a/Sources/ZcashLightClientKit/DAO/PendingTransactionDao.swift +++ b/Sources/ZcashLightClientKit/DAO/PendingTransactionDao.swift @@ -28,7 +28,7 @@ struct PendingTransaction: PendingTransactionEntity, Decodable, Encodable { case rawTransactionId = "txid" case fee } - + var recipient: PendingTransactionRecipient var accountIndex: Int var minedHeight: BlockHeight diff --git a/Sources/ZcashLightClientKit/Initializer.swift b/Sources/ZcashLightClientKit/Initializer.swift index 7ec08cdc..5c07ec66 100644 --- a/Sources/ZcashLightClientKit/Initializer.swift +++ b/Sources/ZcashLightClientKit/Initializer.swift @@ -115,7 +115,6 @@ public class Initializer { // This is used to uniquely identify instance of the SDKSynchronizer. It's used when checking if the Alias is already used or not. let id = UUID() - let rustBackend: ZcashRustBackendWelding.Type let alias: ZcashSynchronizerAlias let endpoint: LightWalletEndpoint let fsBlockDbRoot: URL @@ -131,6 +130,7 @@ public class Initializer { let blockDownloaderService: BlockDownloaderService let network: ZcashNetwork let logger: Logger + let rustBackend: ZcashRustBackendWelding /// The effective birthday of the wallet based on the height provided when initializing and the checkpoints available on this SDK. /// @@ -180,8 +180,16 @@ public class Initializer { let (updatedURLs, parsingError) = Self.tryToUpdateURLs(with: alias, urls: urls) let logger = OSLogger(logLevel: logLevel, alias: alias) + let rustBackend = ZcashRustBackend( + dbData: updatedURLs.dataDbURL, + fsBlockDbRoot: updatedURLs.fsBlockDbRoot, + spendParamsPath: updatedURLs.spendParamsURL, + outputParamsPath: updatedURLs.outputParamsURL, + networkType: network.networkType + ) + self.init( - rustBackend: ZcashRustBackend.self, + rustBackend: rustBackend, network: network, cacheDbURL: cacheDbURL, urls: updatedURLs, @@ -198,7 +206,7 @@ public class Initializer { fsBlockDbRoot: updatedURLs.fsBlockDbRoot, metadataStore: .live( fsBlockDbRoot: updatedURLs.fsBlockDbRoot, - rustBackend: ZcashRustBackend.self, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -216,7 +224,7 @@ public class Initializer { /// /// !!! It's expected that URLs put here are already update with the Alias. init( - rustBackend: ZcashRustBackendWelding.Type, + rustBackend: ZcashRustBackendWelding, network: ZcashNetwork, cacheDbURL: URL?, urls: URLs, @@ -341,7 +349,7 @@ public class Initializer { } do { - if case .seedRequired = try await rustBackend.initDataDb(dbData: dataDbURL, seed: seed, networkType: network.networkType) { + if case .seedRequired = try await rustBackend.initDataDb(seed: seed) { return .seedRequired } } catch { @@ -351,12 +359,10 @@ public class Initializer { let checkpoint = Checkpoint.birthday(with: walletBirthday, network: network) do { try await rustBackend.initBlocksTable( - dbData: dataDbURL, height: Int32(checkpoint.height), hash: checkpoint.hash, time: checkpoint.time, - saplingTree: checkpoint.saplingTree, - networkType: network.networkType + saplingTree: checkpoint.saplingTree ) } catch RustWeldingError.dataDbNotEmpty { // this is fine @@ -367,11 +373,7 @@ public class Initializer { self.walletBirthday = checkpoint.height do { - try await rustBackend.initAccountsTable( - dbData: dataDbURL, - ufvks: viewingKeys, - networkType: network.networkType - ) + try await rustBackend.initAccountsTable(ufvks: viewingKeys) } catch RustWeldingError.dataDbNotEmpty { // this is fine } catch RustWeldingError.malformedStringInput { @@ -395,14 +397,18 @@ public class Initializer { checks if the provided address is a valid sapling address */ public func isValidSaplingAddress(_ address: String) -> Bool { - rustBackend.isValidSaplingAddress(address, networkType: network.networkType) + DerivationTool(networkType: network.networkType).isValidSaplingAddress(address) } /** checks if the provided address is a transparent zAddress */ public func isValidTransparentAddress(_ address: String) -> Bool { - rustBackend.isValidTransparentAddress(address, networkType: network.networkType) + DerivationTool(networkType: network.networkType).isValidTransparentAddress(address) + } + + public func makeDerivationTool() -> DerivationTool { + return DerivationTool(networkType: network.networkType) } } diff --git a/Sources/ZcashLightClientKit/Model/WalletTypes.swift b/Sources/ZcashLightClientKit/Model/WalletTypes.swift index 71647549..06e49060 100644 --- a/Sources/ZcashLightClientKit/Model/WalletTypes.swift +++ b/Sources/ZcashLightClientKit/Model/WalletTypes.swift @@ -63,7 +63,7 @@ public struct UnifiedFullViewingKey: Equatable, StringEncoded, Undescribable { /// - Throws: `KeyEncodingError.invalidEncoding`when the provided encoding is /// found to be invalid public init(encoding: String, account: UInt32, network: NetworkType) throws { - guard DerivationTool.rustwelding.isValidUnifiedFullViewingKey(encoding, networkType: network) else { + guard DerivationTool(networkType: network).isValidUnifiedFullViewingKey(encoding) else { throw KeyEncodingError.invalidEncoding } @@ -85,7 +85,7 @@ public struct SaplingExtendedFullViewingKey: Equatable, StringEncoded, Undescrib /// - Throws: `KeyEncodingError.invalidEncoding`when the provided encoding is /// found to be invalid public init(encoding: String, network: NetworkType) throws { - guard DerivationTool.rustwelding.isValidSaplingExtendedFullViewingKey(encoding, networkType: network) else { + guard ZcashKeyDerivationBackend(networkType: network).isValidSaplingExtendedFullViewingKey(encoding) else { throw KeyEncodingError.invalidEncoding } self.encoding = encoding @@ -174,6 +174,8 @@ public struct SaplingAddress: Equatable, StringEncoded { } public struct UnifiedAddress: Equatable, StringEncoded { + let networkType: NetworkType + public enum Errors: Error { case couldNotExtractTypecodes } @@ -212,6 +214,7 @@ public struct UnifiedAddress: Equatable, StringEncoded { /// - Throws: `KeyEncodingError.invalidEncoding`when the provided encoding is /// found to be invalid public init(encoding: String, network: NetworkType) throws { + networkType = network guard DerivationTool(networkType: network).isValidUnifiedAddress(encoding) else { throw KeyEncodingError.invalidEncoding } @@ -224,7 +227,7 @@ public struct UnifiedAddress: Equatable, StringEncoded { /// couldn't be extracted public func availableReceiverTypecodes() throws -> [UnifiedAddress.ReceiverTypecodes] { do { - return try DerivationTool.receiverTypecodesFromUnifiedAddress(self) + return try DerivationTool(networkType: networkType).receiverTypecodesFromUnifiedAddress(self) } catch { throw Errors.couldNotExtractTypecodes } @@ -272,7 +275,7 @@ public enum Recipient: Equatable, StringEncoded { metadata.networkType) case .p2sh: return (.transparent(TransparentAddress(validatedEncoding: encoded)), metadata.networkType) case .sapling: return (.sapling(SaplingAddress(validatedEncoding: encoded)), metadata.networkType) - case .unified: return (.unified(UnifiedAddress(validatedEncoding: encoded)), metadata.networkType) + case .unified: return (.unified(UnifiedAddress(validatedEncoding: encoded, networkType: metadata.networkType)), metadata.networkType) } } } diff --git a/Sources/ZcashLightClientKit/Rust/ZcashKeyDerivationBackend.swift b/Sources/ZcashLightClientKit/Rust/ZcashKeyDerivationBackend.swift new file mode 100644 index 00000000..195102ad --- /dev/null +++ b/Sources/ZcashLightClientKit/Rust/ZcashKeyDerivationBackend.swift @@ -0,0 +1,225 @@ +// +// ZcashKeyDerivationBackend.swift +// +// +// Created by Francisco Gindre on 4/7/23. +// + +import Foundation +import libzcashlc + +struct ZcashKeyDerivationBackend: ZcashKeyDerivationBackendWelding { + let networkType: NetworkType + + init(networkType: NetworkType) { + self.networkType = networkType + } + + // MARK: Address metadata and validation + static func getAddressMetadata(_ address: String) -> AddressMetadata? { + var networkId: UInt32 = 0 + var addrId: UInt32 = 0 + guard zcashlc_get_address_metadata( + [CChar](address.utf8CString), + &networkId, + &addrId + ) else { + return nil + } + + guard + let network = NetworkType.forNetworkId(networkId), + let addrType = AddressType.forId(addrId) + else { + return nil + } + + return AddressMetadata(network: network, addrType: addrType) + } + + func receiverTypecodesOnUnifiedAddress(_ address: String) throws -> [UInt32] { + guard !address.containsCStringNullBytesBeforeStringEnding() else { + throw RustWeldingError.invalidInput(message: "`address` contains null bytes.") + } + + var len = UInt(0) + + guard let typecodesPointer = zcashlc_get_typecodes_for_unified_address_receivers( + [CChar](address.utf8CString), + &len + ), len > 0 + else { + throw RustWeldingError.malformedStringInput + } + + var typecodes: [UInt32] = [] + + for typecodeIndex in 0 ..< Int(len) { + let pointer = typecodesPointer.advanced(by: typecodeIndex) + + typecodes.append(pointer.pointee) + } + + defer { + zcashlc_free_typecodes(typecodesPointer, len) + } + + return typecodes + } + + func isValidSaplingAddress(_ address: String) -> Bool { + guard !address.containsCStringNullBytesBeforeStringEnding() else { + return false + } + + return zcashlc_is_valid_shielded_address([CChar](address.utf8CString), networkType.networkId) + } + + func isValidSaplingExtendedFullViewingKey(_ key: String) -> Bool { + guard !key.containsCStringNullBytesBeforeStringEnding() else { + return false + } + + return zcashlc_is_valid_viewing_key([CChar](key.utf8CString), networkType.networkId) + } + + func isValidSaplingExtendedSpendingKey(_ key: String) -> Bool { + guard !key.containsCStringNullBytesBeforeStringEnding() else { + return false + } + + return zcashlc_is_valid_sapling_extended_spending_key([CChar](key.utf8CString), networkType.networkId) + } + + func isValidTransparentAddress(_ address: String) -> Bool { + guard !address.containsCStringNullBytesBeforeStringEnding() else { + return false + } + + return zcashlc_is_valid_transparent_address([CChar](address.utf8CString), networkType.networkId) + } + + func isValidUnifiedAddress(_ address: String) -> Bool { + guard !address.containsCStringNullBytesBeforeStringEnding() else { + return false + } + + return zcashlc_is_valid_unified_address([CChar](address.utf8CString), networkType.networkId) + } + + func isValidUnifiedFullViewingKey(_ key: String) -> Bool { + guard !key.containsCStringNullBytesBeforeStringEnding() else { + return false + } + + return zcashlc_is_valid_unified_full_viewing_key([CChar](key.utf8CString), networkType.networkId) + } + + // MARK: Address Derivation + + func deriveUnifiedSpendingKey( + from seed: [UInt8], + accountIndex: Int32 + ) async throws -> UnifiedSpendingKey { + let binaryKeyPtr = seed.withUnsafeBufferPointer { seedBufferPtr in + return zcashlc_derive_spending_key( + seedBufferPtr.baseAddress, + UInt(seed.count), + accountIndex, + networkType.networkId + ) + } + + defer { zcashlc_free_binary_key(binaryKeyPtr) } + + guard let binaryKey = binaryKeyPtr?.pointee else { + throw lastError() ?? .genericError(message: "No error message available") + } + + return binaryKey.unsafeToUnifiedSpendingKey(network: networkType) + } + + func deriveUnifiedFullViewingKey(from spendingKey: UnifiedSpendingKey) async throws -> UnifiedFullViewingKey { + let extfvk = try spendingKey.bytes.withUnsafeBufferPointer { uskBufferPtr -> UnsafeMutablePointer in + guard let extfvk = zcashlc_spending_key_to_full_viewing_key( + uskBufferPtr.baseAddress, + UInt(spendingKey.bytes.count), + networkType.networkId + ) else { + throw lastError() ?? .genericError(message: "No error message available") + } + + return extfvk + } + + defer { zcashlc_string_free(extfvk) } + + guard let derived = String(validatingUTF8: extfvk) else { + throw RustWeldingError.unableToDeriveKeys + } + + return UnifiedFullViewingKey(validatedEncoding: derived, account: spendingKey.account) + } + + func getSaplingReceiver(for uAddr: UnifiedAddress) throws -> SaplingAddress { + guard let saplingCStr = zcashlc_get_sapling_receiver_for_unified_address( + [CChar](uAddr.encoding.utf8CString) + ) else { + throw KeyDerivationErrors.invalidUnifiedAddress + } + + defer { zcashlc_string_free(saplingCStr) } + + guard let saplingReceiverStr = String(validatingUTF8: saplingCStr) else { + throw KeyDerivationErrors.receiverNotFound + } + + return SaplingAddress(validatedEncoding: saplingReceiverStr) + } + + func getTransparentReceiver(for uAddr: UnifiedAddress) throws -> TransparentAddress { + guard let transparentCStr = zcashlc_get_transparent_receiver_for_unified_address( + [CChar](uAddr.encoding.utf8CString) + ) else { + throw KeyDerivationErrors.invalidUnifiedAddress + } + + defer { zcashlc_string_free(transparentCStr) } + + guard let transparentReceiverStr = String(validatingUTF8: transparentCStr) else { + throw KeyDerivationErrors.receiverNotFound + } + + return TransparentAddress(validatedEncoding: transparentReceiverStr) + } + + // MARK: Error Handling + + private func lastError() -> RustWeldingError? { + defer { zcashlc_clear_last_error() } + + guard let message = getLastError() else { + return nil + } + + if message.contains("couldn't load Sapling spend parameters") { + return RustWeldingError.saplingSpendParametersNotFound + } else if message.contains("is not empty") { + return RustWeldingError.dataDbNotEmpty + } + + return RustWeldingError.genericError(message: message) + } + + private func getLastError() -> String? { + let errorLen = zcashlc_last_error_length() + if errorLen > 0 { + let error = UnsafeMutablePointer.allocate(capacity: Int(errorLen)) + zcashlc_error_message_utf8(error, errorLen) + zcashlc_clear_last_error() + return String(validatingUTF8: error) + } else { + return nil + } + } +} diff --git a/Sources/ZcashLightClientKit/Rust/ZcashKeyDerivationBackendWelding.swift b/Sources/ZcashLightClientKit/Rust/ZcashKeyDerivationBackendWelding.swift new file mode 100644 index 00000000..b774cce9 --- /dev/null +++ b/Sources/ZcashLightClientKit/Rust/ZcashKeyDerivationBackendWelding.swift @@ -0,0 +1,82 @@ +// +// ZcashKeyDerivationBackendWelding.swift +// +// +// Created by Michal Fousek on 11.04.2023. +// + +import Foundation + +protocol ZcashKeyDerivationBackendWelding { + /// The network type this `ZcashKeyDerivationBackendWelding` implementation is for + var networkType: NetworkType { get } + + /// Returns the network and address type for the given Zcash address string, + /// if the string represents a valid Zcash address. + /// - Note: not `NetworkType` bound + static func getAddressMetadata(_ address: String) -> AddressMetadata? + + /// Obtains the available receiver typecodes for the given String encoded Unified Address + /// - Parameter address: public key represented as a String + /// - Returns the `[UInt32]` that compose the given UA + /// - Throws `RustWeldingError.invalidInput(message: String)` when the UA is either invalid or malformed + /// - Note: not `NetworkType` bound + func receiverTypecodesOnUnifiedAddress(_ address: String) throws -> [UInt32] + + /// Validates the if the given string is a valid Sapling Address + /// - Parameter address: UTF-8 encoded String to validate + /// - Returns: true when the address is valid. Returns false in any other case + /// - Throws: Error when the provided address belongs to another network + func isValidSaplingAddress(_ address: String) -> Bool + + /// Validates the if the given string is a valid Sapling Extended Full Viewing Key + /// - Parameter key: UTF-8 encoded String to validate + /// - Returns: `true` when the Sapling Extended Full Viewing Key is valid. `false` in any other case + /// - Throws: Error when there's another problem not related to validity of the string in question + func isValidSaplingExtendedFullViewingKey(_ key: String) -> Bool + + /// Validates the if the given string is a valid Sapling Extended Spending Key + /// - Returns: `true` when the Sapling Extended Spending Key is valid, false in any other case. + /// - Throws: Error when the key is semantically valid but it belongs to another network + /// - parameter key: String encoded Extended Spending Key + func isValidSaplingExtendedSpendingKey(_ key: String) -> Bool + + /// Validates the if the given string is a valid Transparent Address + /// - Parameter address: UTF-8 encoded String to validate + /// - Returns: true when the address is valid and transparent. false in any other case + func isValidTransparentAddress(_ address: String) -> Bool + + /// validates whether a string encoded address is a valid Unified Address. + /// - Parameter address: UTF-8 encoded String to validate + /// - Returns: true when the address is valid and transparent. false in any other case + func isValidUnifiedAddress(_ address: String) -> Bool + + /// verifies that the given string-encoded `UnifiedFullViewingKey` is valid. + /// - Parameter ufvk: UTF-8 encoded String to validate + /// - Returns: true when the encoded string is a valid UFVK. false in any other case + func isValidUnifiedFullViewingKey(_ ufvk: String) -> Bool + + /// Derives and returns a unified spending key from the given seed for the given account ID. + /// Returns the binary encoding of the spending key. The caller should manage the memory of (and store, if necessary) the returned spending key in a secure fashion. + /// - Parameter seed: a Byte Array with the seed + /// - Parameter accountIndex:account index that the key can spend from + func deriveUnifiedSpendingKey(from seed: [UInt8], accountIndex: Int32) async throws -> UnifiedSpendingKey + + /// Derives a `UnifiedFullViewingKey` from a `UnifiedSpendingKey` + /// - Parameter spendingKey: the `UnifiedSpendingKey` to derive from + /// - Throws: `RustWeldingError.unableToDeriveKeys` if the SDK couldn't derive the UFVK. + /// - Returns: the derived `UnifiedFullViewingKey` + func deriveUnifiedFullViewingKey(from spendingKey: UnifiedSpendingKey) async throws -> UnifiedFullViewingKey + + /// Returns the Sapling receiver within the given Unified Address, if any. + /// - Parameter uAddr: a `UnifiedAddress` + /// - Returns a `SaplingAddress` if any + /// - Throws `receiverNotFound` when the receiver is not found. `invalidUnifiedAddress` if the UA provided is not valid + func getSaplingReceiver(for uAddr: UnifiedAddress) throws -> SaplingAddress + + /// Returns the transparent receiver within the given Unified Address, if any. + /// - parameter uAddr: a `UnifiedAddress` + /// - Returns a `TransparentAddress` if any + /// - Throws `receiverNotFound` when the receiver is not found. `invalidUnifiedAddress` if the UA provided is not valid + func getTransparentReceiver(for uAddr: UnifiedAddress) throws -> TransparentAddress +} diff --git a/Sources/ZcashLightClientKit/Rust/ZcashRustBackend.swift b/Sources/ZcashLightClientKit/Rust/ZcashRustBackend.swift index bfefe45e..78046ccf 100644 --- a/Sources/ZcashLightClientKit/Rust/ZcashRustBackend.swift +++ b/Sources/ZcashLightClientKit/Rust/ZcashRustBackend.swift @@ -9,12 +9,37 @@ import Foundation import libzcashlc -class ZcashRustBackend: ZcashRustBackendWelding { - static let minimumConfirmations: UInt32 = 10 - static let useZIP317Fees = false - static func createAccount(dbData: URL, seed: [UInt8], networkType: NetworkType) async throws -> UnifiedSpendingKey { - let dbData = dbData.osStr() +actor ZcashRustBackend: ZcashRustBackendWelding { + let minimumConfirmations: UInt32 = 10 + let useZIP317Fees = false + let dbData: (String, UInt) + let fsBlockDbRoot: (String, UInt) + let spendParamsPath: (String, UInt) + let outputParamsPath: (String, UInt) + let keyDeriving: ZcashKeyDerivationBackendWelding + + nonisolated let networkType: NetworkType + + /// Creates instance of `ZcashRustBackend`. + /// - Parameters: + /// - dbData: `URL` pointing to file where data database will be. + /// - fsBlockDbRoot: `URL` pointing to the filesystem root directory where the fsBlock cache is. + /// this directory is expected to contain a `/blocks` sub-directory with the blocks stored in the convened filename + /// format `{height}-{hash}-block`. This directory has must be granted both write and read permissions. + /// - spendParamsPath: `URL` pointing to spend parameters file. + /// - outputParamsPath: `URL` pointing to output parameters file. + /// - networkType: Network type to use. + init(dbData: URL, fsBlockDbRoot: URL, spendParamsPath: URL, outputParamsPath: URL, networkType: NetworkType) { + self.dbData = dbData.osStr() + self.fsBlockDbRoot = fsBlockDbRoot.osPathStr() + self.spendParamsPath = spendParamsPath.osPathStr() + self.outputParamsPath = outputParamsPath.osPathStr() + self.networkType = networkType + self.keyDeriving = ZcashKeyDerivationBackend(networkType: networkType) + } + + func createAccount(seed: [UInt8]) async throws -> UnifiedSpendingKey { guard let ffiBinaryKeyPtr = zcashlc_create_account( dbData.0, dbData.1, @@ -30,20 +55,13 @@ class ZcashRustBackend: ZcashRustBackendWelding { return ffiBinaryKeyPtr.pointee.unsafeToUnifiedSpendingKey(network: networkType) } - // swiftlint:disable function_parameter_count - static func createToAddress( - dbData: URL, + func createToAddress( usk: UnifiedSpendingKey, to address: String, value: Int64, - memo: MemoBytes?, - spendParamsPath: String, - outputParamsPath: String, - networkType: NetworkType - ) async -> Int64 { - let dbData = dbData.osStr() - - return usk.bytes.withUnsafeBufferPointer { uskPtr in + memo: MemoBytes? + ) async throws -> Int64 { + let result = usk.bytes.withUnsafeBufferPointer { uskPtr in zcashlc_create_to_address( dbData.0, dbData.1, @@ -52,55 +70,39 @@ class ZcashRustBackend: ZcashRustBackendWelding { [CChar](address.utf8CString), value, memo?.bytes, - spendParamsPath, - UInt(spendParamsPath.lengthOfBytes(using: .utf8)), - outputParamsPath, - UInt(outputParamsPath.lengthOfBytes(using: .utf8)), + spendParamsPath.0, + spendParamsPath.1, + outputParamsPath.0, + outputParamsPath.1, networkType.networkId, minimumConfirmations, useZIP317Fees ) } + + guard result > 0 else { + throw lastError() ?? .genericError(message: "No error message available") + } + + return result } - static func decryptAndStoreTransaction(dbData: URL, txBytes: [UInt8], minedHeight: Int32, networkType: NetworkType) async -> Bool { - let dbData = dbData.osStr() - return zcashlc_decrypt_and_store_transaction( + func decryptAndStoreTransaction(txBytes: [UInt8], minedHeight: Int32) async throws { + let result = zcashlc_decrypt_and_store_transaction( dbData.0, dbData.1, txBytes, UInt(txBytes.count), UInt32(minedHeight), networkType.networkId - ) != 0 - } + ) - static func deriveUnifiedSpendingKey( - from seed: [UInt8], - accountIndex: Int32, - networkType: NetworkType - ) throws -> UnifiedSpendingKey { - let binaryKeyPtr = seed.withUnsafeBufferPointer { seedBufferPtr in - return zcashlc_derive_spending_key( - seedBufferPtr.baseAddress, - UInt(seed.count), - accountIndex, - networkType.networkId - ) - } - - defer { zcashlc_free_binary_key(binaryKeyPtr) } - - guard let binaryKey = binaryKeyPtr?.pointee else { + guard result != 0 else { throw lastError() ?? .genericError(message: "No error message available") } - - return binaryKey.unsafeToUnifiedSpendingKey(network: networkType) } - static func getBalance(dbData: URL, account: Int32, networkType: NetworkType) async throws -> Int64 { - let dbData = dbData.osStr() - + func getBalance(account: Int32) async throws -> Int64 { let balance = zcashlc_get_balance(dbData.0, dbData.1, account, networkType.networkId) guard balance >= 0 else { @@ -110,13 +112,7 @@ class ZcashRustBackend: ZcashRustBackendWelding { return balance } - static func getCurrentAddress( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> UnifiedAddress { - let dbData = dbData.osStr() - + func getCurrentAddress(account: Int32) async throws -> UnifiedAddress { guard let addressCStr = zcashlc_get_current_address( dbData.0, dbData.1, @@ -132,31 +128,25 @@ class ZcashRustBackend: ZcashRustBackendWelding { throw RustWeldingError.unableToDeriveKeys } - return UnifiedAddress(validatedEncoding: address) + return UnifiedAddress(validatedEncoding: address, networkType: networkType) } - static func getNearestRewindHeight( - dbData: URL, - height: Int32, - networkType: NetworkType - ) async -> Int32 { - let dbData = dbData.osStr() - - return zcashlc_get_nearest_rewind_height( + func getNearestRewindHeight(height: Int32) async throws -> Int32 { + let result = zcashlc_get_nearest_rewind_height( dbData.0, dbData.1, height, networkType.networkId ) + + guard result > 0 else { + throw lastError() ?? .genericError(message: "No error message available") + } + + return result } - static func getNextAvailableAddress( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> UnifiedAddress { - let dbData = dbData.osStr() - + func getNextAvailableAddress(account: Int32) async throws -> UnifiedAddress { guard let addressCStr = zcashlc_get_next_available_address( dbData.0, dbData.1, @@ -172,16 +162,10 @@ class ZcashRustBackend: ZcashRustBackendWelding { throw RustWeldingError.unableToDeriveKeys } - return UnifiedAddress(validatedEncoding: address) + return UnifiedAddress(validatedEncoding: address, networkType: networkType) } - static func getReceivedMemo( - dbData: URL, - idNote: Int64, - networkType: NetworkType - ) async -> Memo? { - let dbData = dbData.osStr() - + func getReceivedMemo(idNote: Int64) async -> Memo? { var contiguousMemoBytes = ContiguousArray(MemoBytes.empty().bytes) var success = false @@ -194,29 +178,7 @@ class ZcashRustBackend: ZcashRustBackendWelding { return (try? MemoBytes(contiguousBytes: contiguousMemoBytes)).flatMap { try? $0.intoMemo() } } - static func getSaplingReceiver(for uAddr: UnifiedAddress) throws -> SaplingAddress { - guard let saplingCStr = zcashlc_get_sapling_receiver_for_unified_address( - [CChar](uAddr.encoding.utf8CString) - ) else { - throw KeyDerivationErrors.invalidUnifiedAddress - } - - defer { zcashlc_string_free(saplingCStr) } - - guard let saplingReceiverStr = String(validatingUTF8: saplingCStr) else { - throw KeyDerivationErrors.receiverNotFound - } - - return SaplingAddress(validatedEncoding: saplingReceiverStr) - } - - static func getSentMemo( - dbData: URL, - idNote: Int64, - networkType: NetworkType - ) async -> Memo? { - let dbData = dbData.osStr() - + func getSentMemo(idNote: Int64) async -> Memo? { var contiguousMemoBytes = ContiguousArray(MemoBytes.empty().bytes) var success = false @@ -229,16 +191,11 @@ class ZcashRustBackend: ZcashRustBackendWelding { return (try? MemoBytes(contiguousBytes: contiguousMemoBytes)).flatMap { try? $0.intoMemo() } } - static func getTransparentBalance( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> Int64 { + func getTransparentBalance(account: Int32) async throws -> Int64 { guard account >= 0 else { throw RustWeldingError.invalidInput(message: "Account index must be non-negative") } - let dbData = dbData.osStr() let balance = zcashlc_get_total_transparent_balance_for_account( dbData.0, dbData.1, @@ -253,8 +210,7 @@ class ZcashRustBackend: ZcashRustBackendWelding { return balance } - static func getVerifiedBalance(dbData: URL, account: Int32, networkType: NetworkType) async throws -> Int64 { - let dbData = dbData.osStr() + func getVerifiedBalance(account: Int32) async throws -> Int64 { let balance = zcashlc_get_verified_balance( dbData.0, dbData.1, @@ -270,17 +226,11 @@ class ZcashRustBackend: ZcashRustBackendWelding { return balance } - static func getVerifiedTransparentBalance( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> Int64 { + func getVerifiedTransparentBalance(account: Int32) async throws -> Int64 { guard account >= 0 else { throw RustWeldingError.invalidInput(message: "`account` must be non-negative") } - let dbData = dbData.osStr() - let balance = zcashlc_get_verified_transparent_balance_for_account( dbData.0, dbData.1, @@ -299,8 +249,8 @@ class ZcashRustBackend: ZcashRustBackendWelding { return balance } - - static func lastError() -> RustWeldingError? { + + private nonisolated func lastError() -> RustWeldingError? { defer { zcashlc_clear_last_error() } guard let message = getLastError() else { @@ -316,43 +266,7 @@ class ZcashRustBackend: ZcashRustBackendWelding { return RustWeldingError.genericError(message: message) } - static func getAddressMetadata(_ address: String) -> AddressMetadata? { - var networkId: UInt32 = 0 - var addrId: UInt32 = 0 - guard zcashlc_get_address_metadata( - [CChar](address.utf8CString), - &networkId, - &addrId - ) else { - return nil - } - - guard let network = NetworkType.forNetworkId(networkId), - let addrType = AddressType.forId(addrId) - else { - return nil - } - - return AddressMetadata(network: network, addrType: addrType) - } - - static func getTransparentReceiver(for uAddr: UnifiedAddress) throws -> TransparentAddress { - guard let transparentCStr = zcashlc_get_transparent_receiver_for_unified_address( - [CChar](uAddr.encoding.utf8CString) - ) else { - throw KeyDerivationErrors.invalidUnifiedAddress - } - - defer { zcashlc_string_free(transparentCStr) } - - guard let transparentReceiverStr = String(validatingUTF8: transparentCStr) else { - throw KeyDerivationErrors.receiverNotFound - } - - return TransparentAddress(validatedEncoding: transparentReceiverStr) - } - - static func getLastError() -> String? { + private nonisolated func getLastError() -> String? { let errorLen = zcashlc_last_error_length() if errorLen > 0 { let error = UnsafeMutablePointer.allocate(capacity: Int(errorLen)) @@ -364,8 +278,7 @@ class ZcashRustBackend: ZcashRustBackendWelding { } } - static func initDataDb(dbData: URL, seed: [UInt8]?, networkType: NetworkType) async throws -> DbInitResult { - let dbData = dbData.osStr() + func initDataDb(seed: [UInt8]?) async throws -> DbInitResult { switch zcashlc_init_data_database(dbData.0, dbData.1, seed, UInt(seed?.count ?? 0), networkType.networkId) { case 0: // ok return DbInitResult.success @@ -375,74 +288,20 @@ class ZcashRustBackend: ZcashRustBackendWelding { throw throwDataDbError(lastError() ?? .genericError(message: "No error message found")) } } - - static func isValidSaplingAddress(_ address: String, networkType: NetworkType) -> Bool { - guard !address.containsCStringNullBytesBeforeStringEnding() else { - return false - } - - return zcashlc_is_valid_shielded_address([CChar](address.utf8CString), networkType.networkId) - } - - static func isValidTransparentAddress(_ address: String, networkType: NetworkType) -> Bool { - guard !address.containsCStringNullBytesBeforeStringEnding() else { - return false - } - return zcashlc_is_valid_transparent_address([CChar](address.utf8CString), networkType.networkId) - } - - static func isValidSaplingExtendedFullViewingKey(_ key: String, networkType: NetworkType) -> Bool { - guard !key.containsCStringNullBytesBeforeStringEnding() else { - return false - } - - return zcashlc_is_valid_viewing_key([CChar](key.utf8CString), networkType.networkId) - } - - static func isValidSaplingExtendedSpendingKey(_ key: String, networkType: NetworkType) -> Bool { - guard !key.containsCStringNullBytesBeforeStringEnding() else { - return false - } - - return zcashlc_is_valid_sapling_extended_spending_key([CChar](key.utf8CString), networkType.networkId) - } - - static func isValidUnifiedAddress(_ address: String, networkType: NetworkType) -> Bool { - guard !address.containsCStringNullBytesBeforeStringEnding() else { - return false - } - - return zcashlc_is_valid_unified_address([CChar](address.utf8CString), networkType.networkId) - } - - static func isValidUnifiedFullViewingKey(_ key: String, networkType: NetworkType) -> Bool { - guard !key.containsCStringNullBytesBeforeStringEnding() else { - return false - } - - return zcashlc_is_valid_unified_full_viewing_key([CChar](key.utf8CString), networkType.networkId) - } - - static func initAccountsTable( - dbData: URL, - ufvks: [UnifiedFullViewingKey], - networkType: NetworkType - ) async throws { - let dbData = dbData.osStr() - + func initAccountsTable(ufvks: [UnifiedFullViewingKey]) async throws { var ffiUfvks: [FFIEncodedKey] = [] for ufvk in ufvks { guard !ufvk.encoding.containsCStringNullBytesBeforeStringEnding() else { throw RustWeldingError.invalidInput(message: "`UFVK` contains null bytes.") } - - guard self.isValidUnifiedFullViewingKey(ufvk.encoding, networkType: networkType) else { + + guard self.keyDeriving.isValidUnifiedFullViewingKey(ufvk.encoding) else { throw RustWeldingError.invalidInput(message: "UFVK is invalid.") } let ufvkCStr = [CChar](String(ufvk.encoding).utf8CString) - + let ufvkPtr = UnsafeMutablePointer.allocate(capacity: ufvkCStr.count) ufvkPtr.initialize(from: ufvkCStr, count: ufvkCStr.count) @@ -476,19 +335,15 @@ class ZcashRustBackend: ZcashRustBackendWelding { } } - static func initBlockMetadataDb(fsBlockDbRoot: URL) async throws -> Bool { - let blockDb = fsBlockDbRoot.osPathStr() - - let result = zcashlc_init_block_metadata_db(blockDb.0, blockDb.1) + func initBlockMetadataDb() async throws { + let result = zcashlc_init_block_metadata_db(fsBlockDbRoot.0, fsBlockDbRoot.1) guard result else { throw lastError() ?? .genericError(message: "`initAccountsTable` failed with unknown error") } - - return result } - static func writeBlocksMetadata(fsBlockDbRoot: URL, blocks: [ZcashCompactBlock]) async throws -> Bool { + func writeBlocksMetadata(blocks: [ZcashCompactBlock]) async throws { var ffiBlockMetaVec: [FFIBlockMeta] = [] for block in blocks { @@ -507,7 +362,7 @@ class ZcashRustBackend: ZcashRustBackendWelding { hashPtr.deallocate() ffiBlockMetaVec.deallocateElements() } - return false + throw RustWeldingError.writeBlocksMetadataAllocationProblem } ffiBlockMetaVec.append( @@ -530,50 +385,35 @@ class ZcashRustBackend: ZcashRustBackendWelding { defer { ffiBlockMetaVec.deallocateElements() } - let result = try contiguousFFIBlocks.withContiguousMutableStorageIfAvailable { ptr in + try contiguousFFIBlocks.withContiguousMutableStorageIfAvailable { ptr in var meta = FFIBlocksMeta() meta.ptr = ptr.baseAddress meta.len = len fsBlocks.initialize(to: meta) - let fsDb = fsBlockDbRoot.osPathStr() - - let res = zcashlc_write_block_metadata(fsDb.0, fsDb.1, fsBlocks) + let res = zcashlc_write_block_metadata(fsBlockDbRoot.0, fsBlockDbRoot.1, fsBlocks) guard res else { throw lastError() ?? RustWeldingError.genericError(message: "failed to write block metadata") } - - return res } - - guard let value = result else { - return false - } - - return value } - // swiftlint:disable function_parameter_count - static func initBlocksTable( - dbData: URL, + func initBlocksTable( height: Int32, hash: String, time: UInt32, - saplingTree: String, - networkType: NetworkType + saplingTree: String ) async throws { - let dbData = dbData.osStr() - guard !hash.containsCStringNullBytesBeforeStringEnding() else { throw RustWeldingError.invalidInput(message: "`hash` contains null bytes.") } - + guard !saplingTree.containsCStringNullBytesBeforeStringEnding() else { throw RustWeldingError.invalidInput(message: "`saplingTree` contains null bytes.") } - + guard zcashlc_init_blocks_table( dbData.0, dbData.1, @@ -587,19 +427,11 @@ class ZcashRustBackend: ZcashRustBackendWelding { } } - static func latestCachedBlockHeight(fsBlockDbRoot: URL) async -> BlockHeight { - let fsBlockDb = fsBlockDbRoot.osPathStr() - - return BlockHeight(zcashlc_latest_cached_block_height(fsBlockDb.0, fsBlockDb.1)) + func latestCachedBlockHeight() async -> BlockHeight { + return BlockHeight(zcashlc_latest_cached_block_height(fsBlockDbRoot.0, fsBlockDbRoot.1)) } - static func listTransparentReceivers( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> [TransparentAddress] { - let dbData = dbData.osStr() - + func listTransparentReceivers(account: Int32) async throws -> [TransparentAddress] { guard let encodedKeysPtr = zcashlc_list_transparent_receivers( dbData.0, dbData.1, @@ -627,18 +459,14 @@ class ZcashRustBackend: ZcashRustBackendWelding { return addresses } - - static func putUnspentTransparentOutput( - dbData: URL, + + func putUnspentTransparentOutput( txid: [UInt8], index: Int, script: [UInt8], value: Int64, - height: BlockHeight, - networkType: NetworkType - ) async throws -> Bool { - let dbData = dbData.osStr() - + height: BlockHeight + ) async throws { guard zcashlc_put_utxo( dbData.0, dbData.1, @@ -653,48 +481,51 @@ class ZcashRustBackend: ZcashRustBackendWelding { ) else { throw lastError() ?? .genericError(message: "No error message available") } - - return true - } - - static func validateCombinedChain(fsBlockDbRoot: URL, dbData: URL, networkType: NetworkType, limit: UInt32 = 0) async -> Int32 { - let dbCache = fsBlockDbRoot.osPathStr() - let dbData = dbData.osStr() - return zcashlc_validate_combined_chain(dbCache.0, dbCache.1, dbData.0, dbData.1, networkType.networkId, limit) - } - - static func rewindToHeight(dbData: URL, height: Int32, networkType: NetworkType) -> Bool { - let dbData = dbData.osStr() - return zcashlc_rewind_to_height(dbData.0, dbData.1, height, networkType.networkId) } - static func rewindCacheToHeight( - fsBlockDbRoot: URL, - height: Int32 - ) -> Bool { - let fsBlockCache = fsBlockDbRoot.osPathStr() + func validateCombinedChain(limit: UInt32 = 0) async throws { + let result = zcashlc_validate_combined_chain(fsBlockDbRoot.0, fsBlockDbRoot.1, dbData.0, dbData.1, networkType.networkId, limit) - return zcashlc_rewind_fs_block_cache_to_height(fsBlockCache.0, fsBlockCache.1, height) + switch result { + case -1: + return + case 0: + throw RustWeldingError.chainValidationFailed(message: getLastError()) + default: + throw RustWeldingError.invalidChain(upperBound: result) + } } - static func scanBlocks(fsBlockDbRoot: URL, dbData: URL, limit: UInt32 = 0, networkType: NetworkType) async -> Bool { - let dbCache = fsBlockDbRoot.osPathStr() - let dbData = dbData.osStr() - return zcashlc_scan_blocks(dbCache.0, dbCache.1, dbData.0, dbData.1, limit, networkType.networkId) != 0 + func rewindToHeight(height: Int32) async throws { + let result = zcashlc_rewind_to_height(dbData.0, dbData.1, height, networkType.networkId) + + guard result else { + throw lastError() ?? .genericError(message: "No error message available") + } } - - static func shieldFunds( - dbData: URL, + + func rewindCacheToHeight(height: Int32) async throws { + let result = zcashlc_rewind_fs_block_cache_to_height(fsBlockDbRoot.0, fsBlockDbRoot.1, height) + + guard result else { + throw lastError() ?? .genericError(message: "No error message available") + } + } + + func scanBlocks(limit: UInt32 = 0) async throws { + let result = zcashlc_scan_blocks(fsBlockDbRoot.0, fsBlockDbRoot.1, dbData.0, dbData.1, limit, networkType.networkId) + + guard result != 0 else { + throw lastError() ?? .genericError(message: "No error message available") + } + } + + func shieldFunds( usk: UnifiedSpendingKey, memo: MemoBytes?, - shieldingThreshold: Zatoshi, - spendParamsPath: String, - outputParamsPath: String, - networkType: NetworkType - ) async -> Int64 { - let dbData = dbData.osStr() - - return usk.bytes.withUnsafeBufferPointer { uskBuffer in + shieldingThreshold: Zatoshi + ) async throws -> Int64 { + let result = usk.bytes.withUnsafeBufferPointer { uskBuffer in zcashlc_shield_funds( dbData.0, dbData.1, @@ -702,82 +533,36 @@ class ZcashRustBackend: ZcashRustBackendWelding { UInt(usk.bytes.count), memo?.bytes, UInt64(shieldingThreshold.amount), - spendParamsPath, - UInt(spendParamsPath.lengthOfBytes(using: .utf8)), - outputParamsPath, - UInt(outputParamsPath.lengthOfBytes(using: .utf8)), + spendParamsPath.0, + spendParamsPath.1, + outputParamsPath.0, + outputParamsPath.1, networkType.networkId, minimumConfirmations, useZIP317Fees ) } - } - - static func deriveUnifiedFullViewingKey(from spendingKey: UnifiedSpendingKey, networkType: NetworkType) throws -> UnifiedFullViewingKey { - let extfvk = try spendingKey.bytes.withUnsafeBufferPointer { uskBufferPtr -> UnsafeMutablePointer in - guard let extfvk = zcashlc_spending_key_to_full_viewing_key( - uskBufferPtr.baseAddress, - UInt(spendingKey.bytes.count), - networkType.networkId - ) else { - throw lastError() ?? .genericError(message: "No error message available") - } - return extfvk + guard result > 0 else { + throw lastError() ?? .genericError(message: "No error message available") } - defer { zcashlc_string_free(extfvk) } - - guard let derived = String(validatingUTF8: extfvk) else { - throw RustWeldingError.unableToDeriveKeys - } - - return UnifiedFullViewingKey(validatedEncoding: derived, account: spendingKey.account) + return result } - static func receiverTypecodesOnUnifiedAddress(_ address: String) throws -> [UInt32] { - guard !address.containsCStringNullBytesBeforeStringEnding() else { - throw RustWeldingError.invalidInput(message: "`address` contains null bytes.") - } - - var len = UInt(0) - - guard let typecodesPointer = zcashlc_get_typecodes_for_unified_address_receivers( - [CChar](address.utf8CString), - &len - ), len > 0 - else { - throw RustWeldingError.malformedStringInput - } - - var typecodes: [UInt32] = [] - - for typecodeIndex in 0 ..< Int(len) { - let pointer = typecodesPointer.advanced(by: typecodeIndex) - - typecodes.append(pointer.pointee) - } - - defer { - zcashlc_free_typecodes(typecodesPointer, len) - } - - return typecodes - } - - static func consensusBranchIdFor(height: Int32, networkType: NetworkType) throws -> Int32 { + nonisolated func consensusBranchIdFor(height: Int32) throws -> Int32 { let branchId = zcashlc_branch_id_for_height(height, networkType.networkId) - + guard branchId != -1 else { throw RustWeldingError.noConsensusBranchId(height: height) } - + return branchId } } private extension ZcashRustBackend { - static func throwDataDbError(_ error: RustWeldingError) -> Error { + func throwDataDbError(_ error: RustWeldingError) -> Error { if case RustWeldingError.genericError(let message) = error, message.contains("is not empty") { return RustWeldingError.dataDbNotEmpty } @@ -785,7 +570,7 @@ private extension ZcashRustBackend { return RustWeldingError.dataDbInitFailed(message: error.localizedDescription) } - static func throwBalanceError(account: Int32, _ error: RustWeldingError?, fallbackMessage: String) -> Error { + func throwBalanceError(account: Int32, _ error: RustWeldingError?, fallbackMessage: String) -> Error { guard let balanceError = error else { return RustWeldingError.genericError(message: fallbackMessage) } @@ -865,6 +650,15 @@ extension RustWeldingError: LocalizedError { return "`.unableToDeriveKeys` the requested keys could not be derived from the source provided" case let .getBalanceError(account, error): return "`.getBalanceError` could not retrieve balance from account: \(account), error:\(error)" + case let .invalidChain(upperBound: upperBound): + return "`.validateCombinedChain` failed to validate chain. Upper bound: \(upperBound)." + case let .chainValidationFailed(message): + return """ + `.validateCombinedChain` failed to validate chain because of error unrelated to chain validity. \ + Message: \(String(describing: message)) + """ + case .writeBlocksMetadataAllocationProblem: + return "`.writeBlocksMetadata` failed to allocate memory on Swift side necessary to write blocks metadata to db." } } } diff --git a/Sources/ZcashLightClientKit/Rust/ZcashRustBackendWelding.swift b/Sources/ZcashLightClientKit/Rust/ZcashRustBackendWelding.swift index e39da64a..fcb100ae 100644 --- a/Sources/ZcashLightClientKit/Rust/ZcashRustBackendWelding.swift +++ b/Sources/ZcashLightClientKit/Rust/ZcashRustBackendWelding.swift @@ -19,6 +19,13 @@ enum RustWeldingError: Error { case getBalanceError(Int, Error) case invalidInput(message: String) case invalidRewind(suggestedHeight: Int32) + /// Thrown when `upperBound` if the combined chain is invalid. `upperBound` is the height of the highest invalid block (on the assumption that + /// the highest block in the cache database is correct). + case invalidChain(upperBound: Int32) + /// Thrown if there was an error during validation unrelated to chain validity. + case chainValidationFailed(message: String?) + /// Thrown if there was problem with memory allocation on the Swift side while trying to write blocks metadata to DB. + case writeBlocksMetadataAllocationProblem } enum ZcashRustBackendWeldingConstants { @@ -31,6 +38,7 @@ public enum DbInitResult { case seedRequired } +// sourcery: mockActor protocol ZcashRustBackendWelding { /// Adds the next available account-level spend authority, given the current set of [ZIP 316] /// account identifiers known, to the wallet database. @@ -46,83 +54,34 @@ protocol ZcashRustBackendWelding { /// By convention, wallets should only allow a new account to be generated after funds /// have been received by the currently-available account (in order to enable /// automated account recovery). - /// - parameter dbData: location of the data db /// - parameter seed: byte array of the zip32 seed - /// - parameter networkType: network type of this key /// - Returns: The `UnifiedSpendingKey` structs for the number of accounts created - /// - static func createAccount( - dbData: URL, - seed: [UInt8], - networkType: NetworkType - ) async throws -> UnifiedSpendingKey + func createAccount(seed: [UInt8]) async throws -> UnifiedSpendingKey /// Creates a transaction to the given address from the given account - /// - Parameter dbData: URL for the Data DB /// - Parameter usk: `UnifiedSpendingKey` for the account that controls the funds to be spent. /// - Parameter to: recipient address /// - Parameter value: transaction amount in Zatoshi /// - Parameter memo: the `MemoBytes` for this transaction. pass `nil` when sending to transparent receivers - /// - Parameter spendParamsPath: path escaped String for the filesystem locations where the spend parameters are located - /// - Parameter outputParamsPath: path escaped String for the filesystem locations where the output parameters are located - /// - Parameter networkType: network type of this key - // swiftlint:disable:next function_parameter_count - static func createToAddress( - dbData: URL, + func createToAddress( usk: UnifiedSpendingKey, to address: String, value: Int64, - memo: MemoBytes?, - spendParamsPath: String, - outputParamsPath: String, - networkType: NetworkType - ) async -> Int64 // swiftlint:disable function_parameter_count - - /// Scans a transaction for any information that can be decrypted by the accounts in the - /// wallet, and saves it to the wallet. - /// - parameter dbData: location of the data db file - /// - parameter tx: the transaction to decrypt - /// - parameter minedHeight: height on which this transaction was mined. this is used to fetch the consensus branch ID. - /// - parameter networkType: network type of this key - /// returns false if fails to decrypt. - static func decryptAndStoreTransaction( - dbData: URL, - txBytes: [UInt8], - minedHeight: Int32, - networkType: NetworkType - ) async -> Bool - - /// Derives and returns a unified spending key from the given seed for the given account ID. - /// Returns the binary encoding of the spending key. The caller should manage the memory of (and store, if necessary) the returned spending key in a secure fashion. - /// - Parameter seed: a Byte Array with the seed - /// - Parameter accountIndex:account index that the key can spend from - /// - Parameter networkType: network type of this key - /// - Throws `.unableToDerive` when there's an error - static func deriveUnifiedSpendingKey( - from seed: [UInt8], - accountIndex: Int32, - networkType: NetworkType - ) throws -> UnifiedSpendingKey - - /// get the (unverified) balance from the given account - /// - parameter dbData: location of the data db - /// - parameter account: index of the given account - /// - parameter networkType: network type of this key - static func getBalance( - dbData: URL, - account: Int32, - networkType: NetworkType + memo: MemoBytes? ) async throws -> Int64 - /// Returns the most-recently-generated unified payment address for the specified account. - /// - parameter dbData: location of the data db + /// Scans a transaction for any information that can be decrypted by the accounts in the wallet, and saves it to the wallet. + /// - parameter tx: the transaction to decrypt + /// - parameter minedHeight: height on which this transaction was mined. this is used to fetch the consensus branch ID. + func decryptAndStoreTransaction(txBytes: [UInt8], minedHeight: Int32) async throws + + /// Get the (unverified) balance from the given account. /// - parameter account: index of the given account - /// - parameter networkType: network type of this key - static func getCurrentAddress( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> UnifiedAddress + func getBalance(account: Int32) async throws -> Int64 + + /// Returns the most-recently-generated unified payment address for the specified account. + /// - parameter account: index of the given account + func getCurrentAddress(account: Int32) async throws -> UnifiedAddress /// Wallets might need to be rewound because of a reorg, or by user request. /// There are times where the wallet could get out of sync for many reasons and @@ -131,195 +90,64 @@ protocol ZcashRustBackendWelding { /// of sapling witnesses older than 100 blocks. So in order to reconstruct the witness /// tree that allows to spend notes from the given wallet the rewind can't be more than /// 100 blocks or back to the oldest unspent note that this wallet contains. - /// - parameter dbData: location of the data db file /// - parameter height: height you would like to rewind to. - /// - parameter networkType: network type of this key] /// - Returns: the blockheight of the nearest rewind height. - /// - static func getNearestRewindHeight( - dbData: URL, - height: Int32, - networkType: NetworkType - ) async -> Int32 + func getNearestRewindHeight(height: Int32) async throws -> Int32 /// Returns a newly-generated unified payment address for the specified account, with the next available diversifier. - /// - parameter dbData: location of the data db /// - parameter account: index of the given account - /// - parameter networkType: network type of this key - static func getNextAvailableAddress( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> UnifiedAddress + func getNextAvailableAddress(account: Int32) async throws -> UnifiedAddress - /// get received memo from note - /// - parameter dbData: location of the data db file + /// Get received memo from note. /// - parameter idNote: note_id of note where the memo is located - /// - parameter networkType: network type of this key - static func getReceivedMemo( - dbData: URL, - idNote: Int64, - networkType: NetworkType - ) async -> Memo? + func getReceivedMemo(idNote: Int64) async -> Memo? - /// Returns the Sapling receiver within the given Unified Address, if any. - /// - Parameter uAddr: a `UnifiedAddress` - /// - Returns a `SaplingAddress` if any - /// - Throws `receiverNotFound` when the receiver is not found. `invalidUnifiedAddress` if the UA provided is not valid - static func getSaplingReceiver(for uAddr: UnifiedAddress) throws -> SaplingAddress - - /// get sent memo from note - /// - parameter dbData: location of the data db file + /// Get sent memo from note. /// - parameter idNote: note_id of note where the memo is located - /// - parameter networkType: network type of this key /// - Returns: a `Memo` if any - static func getSentMemo( - dbData: URL, - idNote: Int64, - networkType: NetworkType - ) async -> Memo? + func getSentMemo(idNote: Int64) async -> Memo? /// Get the verified cached transparent balance for the given address - /// - parameter dbData: location of the data db file /// - parameter account; the account index to query - /// - parameter networkType: network type of this key - static func getTransparentBalance( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> Int64 + func getTransparentBalance(account: Int32) async throws -> Int64 - /// Returns the transparent receiver within the given Unified Address, if any. - /// - parameter uAddr: a `UnifiedAddress` - /// - Returns a `TransparentAddress` if any - /// - Throws `receiverNotFound` when the receiver is not found. `invalidUnifiedAddress` if the UA provided is not valid - static func getTransparentReceiver(for uAddr: UnifiedAddress) throws -> TransparentAddress - - /// gets the latest error if available. Clear the existing error - /// - Returns a `RustWeldingError` if exists - static func lastError() -> RustWeldingError? - - /// gets the latest error message from librustzcash. Does not clear existing error - static func getLastError() -> String? - - /// initialize the accounts table from a set of unified full viewing keys - /// - Note: this function should only be used when restoring an existing seed phrase. - /// when creating a new wallet, use `createAccount()` instead - /// - Parameter dbData: location of the data db + /// Initialize the accounts table from a set of unified full viewing keys. + /// - Note: this function should only be used when restoring an existing seed phrase. when creating a new wallet, use `createAccount()` instead. /// - Parameter ufvks: an array of UnifiedFullViewingKeys - /// - Parameter networkType: network type of this key - static func initAccountsTable( - dbData: URL, - ufvks: [UnifiedFullViewingKey], - networkType: NetworkType - ) async throws + func initAccountsTable(ufvks: [UnifiedFullViewingKey]) async throws - /// initializes the data db. This will performs any migrations needed on the sqlite file + /// Initializes the data db. This will performs any migrations needed on the sqlite file /// provided. Some migrations might need that callers provide the seed bytes. - /// - Parameter dbData: location of the data db sql file /// - Parameter seed: ZIP-32 compliant seed bytes for this wallet - /// - Parameter networkType: network type of this key /// - Returns: `DbInitResult.success` if the dataDb was initialized successfully /// or `DbInitResult.seedRequired` if the operation requires the seed to be passed /// in order to be completed successfully. - static func initDataDb( - dbData: URL, - seed: [UInt8]?, - networkType: NetworkType - ) async throws -> DbInitResult + func initDataDb(seed: [UInt8]?) async throws -> DbInitResult - /// Returns the network and address type for the given Zcash address string, - /// if the string represents a valid Zcash address. - static func getAddressMetadata(_ address: String) -> AddressMetadata? - - /// Validates the if the given string is a valid Sapling Address - /// - Parameter address: UTF-8 encoded String to validate - /// - Parameter networkType: network type of this key - /// - Returns: true when the address is valid. Returns false in any other case - /// - Throws: Error when the provided address belongs to another network - static func isValidSaplingAddress(_ address: String, networkType: NetworkType) -> Bool - - /// Validates the if the given string is a valid Sapling Extended Full Viewing Key - /// - Parameter key: UTF-8 encoded String to validate - /// - Parameter networkType: network type of this key - /// - Returns: `true` when the Sapling Extended Full Viewing Key is valid. `false` in any other case - /// - Throws: Error when there's another problem not related to validity of the string in question - static func isValidSaplingExtendedFullViewingKey(_ key: String, networkType: NetworkType) -> Bool - - /// Validates the if the given string is a valid Sapling Extended Spending Key - /// - Returns: `true` when the Sapling Extended Spending Key is valid, false in any other case. - /// - Throws: Error when the key is semantically valid but it belongs to another network - /// - parameter key: String encoded Extended Spending Key - /// - parameter networkType: `NetworkType` signaling testnet or mainnet - static func isValidSaplingExtendedSpendingKey(_ key: String, networkType: NetworkType) -> Bool - - /// Validates the if the given string is a valid Transparent Address - /// - Parameter address: UTF-8 encoded String to validate - /// - Parameter networkType: network type of this key - /// - Returns: true when the address is valid and transparent. false in any other case - /// - Throws: Error when the provided address belongs to another network - static func isValidTransparentAddress(_ address: String, networkType: NetworkType) -> Bool - - /// validates whether a string encoded address is a valid Unified Address. - /// - Parameter address: UTF-8 encoded String to validate - /// - Parameter networkType: network type of this key - /// - Returns: true when the address is valid and transparent. false in any other case - /// - Throws: Error when the provided address belongs to another network - static func isValidUnifiedAddress(_ address: String, networkType: NetworkType) -> Bool - - /// verifies that the given string-encoded `UnifiedFullViewingKey` is valid. - /// - Parameter ufvk: UTF-8 encoded String to validate - /// - Parameter networkType: network type of this key - /// - Returns: true when the encoded string is a valid UFVK. false in any other case - /// - Throws: Error when there's another problem not related to validity of the string in question - static func isValidUnifiedFullViewingKey(_ ufvk: String, networkType: NetworkType) -> Bool - - /// initialize the blocks table from a given checkpoint (heigh, hash, time, saplingTree and networkType) - /// - parameter dbData: location of the data db + /// Initialize the blocks table from a given checkpoint (heigh, hash, time, saplingTree and networkType). /// - parameter height: represents the block height of the given checkpoint /// - parameter hash: hash of the merkle tree /// - parameter time: in milliseconds from reference /// - parameter saplingTree: hash of the sapling tree - /// - parameter networkType: `NetworkType` signaling testnet or mainnet - static func initBlocksTable( - dbData: URL, + func initBlocksTable( height: Int32, hash: String, time: UInt32, - saplingTree: String, - networkType: NetworkType - ) async throws // swiftlint:disable function_parameter_count + saplingTree: String + ) async throws /// Returns a list of the transparent receivers for the diversified unified addresses that have /// been allocated for the provided account. - /// - parameter dbData: location of the data db /// - parameter account: index of the given account - /// - parameter networkType: the network type - static func listTransparentReceivers( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> [TransparentAddress] + func listTransparentReceivers(account: Int32) async throws -> [TransparentAddress] - /// get the verified balance from the given account - /// - parameter dbData: location of the data db + /// Get the verified balance from the given account /// - parameter account: index of the given account - /// - parameter networkType: the network type - static func getVerifiedBalance( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> Int64 + func getVerifiedBalance(account: Int32) async throws -> Int64 /// Get the verified cached transparent balance for the given account - /// - parameter dbData: location of the data db /// - parameter account: account index to query the balance for. - /// - parameter networkType: the network type - static func getVerifiedTransparentBalance( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> Int64 + func getVerifiedTransparentBalance(account: Int32) async throws -> Int64 /// Checks that the scanned blocks in the data database, when combined with the recent /// `CompactBlock`s in the cache database, form a valid chain. @@ -333,39 +161,22 @@ protocol ZcashRustBackendWelding { /// - parameter dbData: location of the data db file /// - parameter networkType: the network type /// - parameter limit: a limit to validate a fixed number of blocks instead of the whole cache. - /// - Returns: - /// - `-1` if the combined chain is valid. - /// - `upper_bound` if the combined chain is invalid. - /// - `upper_bound` is the height of the highest invalid block (on the assumption that the highest block in the cache database is correct). - /// - `0` if there was an error during validation unrelated to chain validity. + /// - Throws: + /// - `RustWeldingError.chainValidationFailed` if there was an error during validation unrelated to chain validity. + /// - `RustWeldingError.invalidChain(upperBound)` if the combined chain is invalid. `upperBound` is the height of the highest invalid block + /// (on the assumption that the highest block in the cache database is correct). + /// /// - Important: This function does not mutate either of the databases. - static func validateCombinedChain( - fsBlockDbRoot: URL, - dbData: URL, - networkType: NetworkType, - limit: UInt32 - ) async -> Int32 + func validateCombinedChain(limit: UInt32) async throws /// Resets the state of the database to only contain block and transaction information up to the given height. clears up all derived data as well - /// - parameter dbData: `URL` pointing to the filesystem root directory where the fsBlock cache is. - /// this directory is expected to contain a `/blocks` sub-directory with the blocks stored in the convened filename - /// format `{height}-{hash}-block`. This directory has must be granted both write and read permissions. /// - parameter height: height to rewind to. - static func rewindToHeight( - dbData: URL, - height: Int32, - networkType: NetworkType - ) async -> Bool + func rewindToHeight(height: Int32) async throws /// Resets the state of the FsBlock database to only contain block and transaction information up to the given height. /// - Note: this does not delete the files. Only rolls back the database. - /// - parameter fsBlockDbRoot: location of the data db file - /// - parameter height: height to rewind to. DON'T PASS ARBITRARY HEIGHT. Use getNearestRewindHeight when unsure - /// - parameter networkType: the network type - static func rewindCacheToHeight( - fsBlockDbRoot: URL, - height: Int32 - ) async -> Bool + /// - parameter height: height to rewind to. DON'T PASS ARBITRARY HEIGHT. Use `getNearestRewindHeight` when unsure + func rewindCacheToHeight(height: Int32) async throws /// Scans new blocks added to the cache for any transactions received by the tracked /// accounts. @@ -379,101 +190,48 @@ protocol ZcashRustBackendWelding { /// Scanned blocks are required to be height-sequential. If a block is missing from the /// cache, an error will be signalled. /// - /// - parameter fsBlockDbRoot: `URL` pointing to the filesystem root directory where the fsBlock cache is. - /// this directory is expected to contain a `/blocks` sub-directory with the blocks stored in the convened filename - /// format `{height}-{hash}-block`. This directory has must be granted both write and read permissions. - /// - parameter dbData: location of the data db sqlite file /// - parameter limit: scan up to limit blocks. pass 0 to set no limit. - /// - parameter networkType: the network type - /// returns false if fails to scan. - static func scanBlocks( - fsBlockDbRoot: URL, - dbData: URL, - limit: UInt32, - networkType: NetworkType - ) async -> Bool + func scanBlocks(limit: UInt32) async throws /// Upserts a UTXO into the data db database - /// - parameter dbData: location of the data db file /// - parameter txid: the txid bytes for the UTXO /// - parameter index: the index of the UTXO /// - parameter script: the script of the UTXO /// - parameter value: the value of the UTXO /// - parameter height: the mined height for the UTXO - /// - parameter networkType: the network type - /// - Returns: true if the operation succeded or false otherwise - static func putUnspentTransparentOutput( - dbData: URL, + func putUnspentTransparentOutput( txid: [UInt8], index: Int, script: [UInt8], value: Int64, - height: BlockHeight, - networkType: NetworkType - ) async throws -> Bool + height: BlockHeight + ) async throws /// Creates a transaction to shield all found UTXOs in data db for the account the provided `UnifiedSpendingKey` has spend authority for. - /// - Parameter dbData: URL for the Data DB /// - Parameter usk: `UnifiedSpendingKey` that spend transparent funds and where the funds will be shielded to. /// - Parameter memo: the `Memo` for this transaction - /// - Parameter spendParamsPath: path escaped String for the filesystem locations where the spend parameters are located - /// - Parameter outputParamsPath: path escaped String for the filesystem locations where the output parameters are located - /// - Parameter networkType: the network type - static func shieldFunds( - dbData: URL, + func shieldFunds( usk: UnifiedSpendingKey, memo: MemoBytes?, - shieldingThreshold: Zatoshi, - spendParamsPath: String, - outputParamsPath: String, - networkType: NetworkType - ) async -> Int64 - - /// Obtains the available receiver typecodes for the given String encoded Unified Address - /// - Parameter address: public key represented as a String - /// - Returns the `[UInt32]` that compose the given UA - /// - Throws `RustWeldingError.invalidInput(message: String)` when the UA is either invalid or malformed - static func receiverTypecodesOnUnifiedAddress(_ address: String) throws -> [UInt32] + shieldingThreshold: Zatoshi + ) async throws -> Int64 /// Gets the consensus branch id for the given height /// - Parameter height: the height you what to know the branch id for - /// - Parameter networkType: the network type - static func consensusBranchIdFor( - height: Int32, - networkType: NetworkType - ) throws -> Int32 + func consensusBranchIdFor(height: Int32) throws -> Int32 - /// Derives a `UnifiedFullViewingKey` from a `UnifiedSpendingKey` - /// - Parameter spendingKey: the `UnifiedSpendingKey` to derive from - /// - Parameter networkType: the network type - /// - Throws: `RustWeldingError.unableToDeriveKeys` if the SDK couldn't derive the UFVK. - /// - Returns: the derived `UnifiedFullViewingKey` - static func deriveUnifiedFullViewingKey( - from spendingKey: UnifiedSpendingKey, - networkType: NetworkType - ) throws -> UnifiedFullViewingKey - - /// initializes Filesystem based block cache - /// - Parameter fsBlockDbRoot: `URL` pointing to the filesystem root directory where the fsBlock cache is. - /// this directory is expected to contain a `/blocks` sub-directory with the blocks stored in the convened filename - /// format `{height}-{hash}-block`. This directory has must be granted both write and read permissions. - /// - returns `true` when successful, `false` when fails but no throwing information was found + /// Initializes Filesystem based block cache /// - throws `RustWeldingError` when fails to initialize - static func initBlockMetadataDb(fsBlockDbRoot: URL) async throws -> Bool + func initBlockMetadataDb() async throws /// Write compact block metadata to a database known to the Rust layer - /// - Parameter fsBlockDbRoot: `URL` pointing to the filesystem root directory where the fsBlock cache is. - /// this directory is expected to contain a `/blocks` sub-directory with the blocks stored in the convened filename - /// format `{height}-{hash}-block`. This directory has must be granted both write and read permissions. - /// - Parameter blocks: The `ZcashCompactBlock`s that are going to be marked as stored by the - /// metadata Db. - /// - Returns `true` if the operation was successful, `false` otherwise. - static func writeBlocksMetadata(fsBlockDbRoot: URL, blocks: [ZcashCompactBlock]) async throws -> Bool + /// - Parameter blocks: The `ZcashCompactBlock`s that are going to be marked as stored by the metadata Db. + func writeBlocksMetadata(blocks: [ZcashCompactBlock]) async throws /// Gets the latest block height stored in the filesystem based cache. /// - Parameter fsBlockDbRoot: `URL` pointing to the filesystem root directory where the fsBlock cache is. /// this directory is expected to contain a `/blocks` sub-directory with the blocks stored in the convened filename /// format `{height}-{hash}-block`. This directory has must be granted both write and read permissions. /// - Returns `BlockHeight` of the latest cached block or `.empty` if no blocks are stored. - static func latestCachedBlockHeight(fsBlockDbRoot: URL) async -> BlockHeight + func latestCachedBlockHeight() async -> BlockHeight } diff --git a/Sources/ZcashLightClientKit/Synchronizer.swift b/Sources/ZcashLightClientKit/Synchronizer.swift index fed08df2..bbbcc5b0 100644 --- a/Sources/ZcashLightClientKit/Synchronizer.swift +++ b/Sources/ZcashLightClientKit/Synchronizer.swift @@ -230,24 +230,34 @@ public protocol Synchronizer: AnyObject { func paginatedTransactions(of kind: TransactionKind) -> PaginatedTransactionRepository /// Get all memos for `transaction`. + /// + // sourcery: mockedName="getMemosForClearedTransaction" func getMemos(for transaction: ZcashTransaction.Overview) async throws -> [Memo] /// Get all memos for `receivedTransaction`. + /// + // sourcery: mockedName="getMemosForReceivedTransaction" func getMemos(for receivedTransaction: ZcashTransaction.Received) async throws -> [Memo] /// Get all memos for `sentTransaction`. + /// + // sourcery: mockedName="getMemosForSentTransaction" func getMemos(for sentTransaction: ZcashTransaction.Sent) async throws -> [Memo] /// Attempt to get recipients from a Transaction Overview. /// - parameter transaction: A transaction overview /// - returns the recipients or an empty array if no recipients are found on this transaction because it's not an outgoing /// transaction + /// + // sourcery: mockedName="getRecipientsForClearedTransaction" func getRecipients(for transaction: ZcashTransaction.Overview) async -> [TransactionRecipient] /// Get the recipients for the given a sent transaction /// - parameter transaction: A transaction overview /// - returns the recipients or an empty array if no recipients are found on this transaction because it's not an outgoing /// transaction + /// + // sourcery: mockedName="getRecipientsForSentTransaction" func getRecipients(for transaction: ZcashTransaction.Sent) async -> [TransactionRecipient] /// Returns a list of confirmed transactions that preceed the given transaction with a limit count. diff --git a/Sources/ZcashLightClientKit/Synchronizer/ClosureSDKSynchronizer.swift b/Sources/ZcashLightClientKit/Synchronizer/ClosureSDKSynchronizer.swift index 915c9719..d5fff219 100644 --- a/Sources/ZcashLightClientKit/Synchronizer/ClosureSDKSynchronizer.swift +++ b/Sources/ZcashLightClientKit/Synchronizer/ClosureSDKSynchronizer.swift @@ -37,37 +37,37 @@ extension ClosureSDKSynchronizer: ClosureSynchronizer { walletBirthday: BlockHeight, completion: @escaping (Result) -> Void ) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { return try await self.synchronizer.prepare(with: seed, viewingKeys: viewingKeys, walletBirthday: walletBirthday) } } public func start(retry: Bool, completion: @escaping (Error?) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.start(retry: retry) } } public func stop(completion: @escaping () -> Void) { - executeAction(completion) { + AsyncToClosureGateway.executeAction(completion) { await self.synchronizer.stop() } } public func getSaplingAddress(accountIndex: Int, completion: @escaping (Result) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.getSaplingAddress(accountIndex: accountIndex) } } public func getUnifiedAddress(accountIndex: Int, completion: @escaping (Result) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.getUnifiedAddress(accountIndex: accountIndex) } } public func getTransparentAddress(accountIndex: Int, completion: @escaping (Result) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.getTransparentAddress(accountIndex: accountIndex) } } @@ -79,7 +79,7 @@ extension ClosureSDKSynchronizer: ClosureSynchronizer { memo: Memo?, completion: @escaping (Result) -> Void ) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.sendToAddress(spendingKey: spendingKey, zatoshi: zatoshi, toAddress: toAddress, memo: memo) } } @@ -90,37 +90,37 @@ extension ClosureSDKSynchronizer: ClosureSynchronizer { shieldingThreshold: Zatoshi, completion: @escaping (Result) -> Void ) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.shieldFunds(spendingKey: spendingKey, memo: memo, shieldingThreshold: shieldingThreshold) } } public func cancelSpend(transaction: PendingTransactionEntity, completion: @escaping (Bool) -> Void) { - executeAction(completion) { + AsyncToClosureGateway.executeAction(completion) { await self.synchronizer.cancelSpend(transaction: transaction) } } public func pendingTransactions(completion: @escaping ([PendingTransactionEntity]) -> Void) { - executeAction(completion) { + AsyncToClosureGateway.executeAction(completion) { await self.synchronizer.pendingTransactions } } public func clearedTransactions(completion: @escaping ([ZcashTransaction.Overview]) -> Void) { - executeAction(completion) { + AsyncToClosureGateway.executeAction(completion) { await self.synchronizer.clearedTransactions } } public func sentTranscations(completion: @escaping ([ZcashTransaction.Sent]) -> Void) { - executeAction(completion) { + AsyncToClosureGateway.executeAction(completion) { await self.synchronizer.sentTransactions } } public func receivedTransactions(completion: @escaping ([ZcashTransaction.Received]) -> Void) { - executeAction(completion) { + AsyncToClosureGateway.executeAction(completion) { await self.synchronizer.receivedTransactions } } @@ -128,31 +128,31 @@ extension ClosureSDKSynchronizer: ClosureSynchronizer { public func paginatedTransactions(of kind: TransactionKind) -> PaginatedTransactionRepository { synchronizer.paginatedTransactions(of: kind) } public func getMemos(for transaction: ZcashTransaction.Overview, completion: @escaping (Result<[Memo], Error>) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.getMemos(for: transaction) } } public func getMemos(for receivedTransaction: ZcashTransaction.Received, completion: @escaping (Result<[Memo], Error>) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.getMemos(for: receivedTransaction) } } public func getMemos(for sentTransaction: ZcashTransaction.Sent, completion: @escaping (Result<[Memo], Error>) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.getMemos(for: sentTransaction) } } public func getRecipients(for transaction: ZcashTransaction.Overview, completion: @escaping ([TransactionRecipient]) -> Void) { - executeAction(completion) { + AsyncToClosureGateway.executeAction(completion) { await self.synchronizer.getRecipients(for: transaction) } } public func getRecipients(for transaction: ZcashTransaction.Sent, completion: @escaping ([TransactionRecipient]) -> Void) { - executeAction(completion) { + AsyncToClosureGateway.executeAction(completion) { await self.synchronizer.getRecipients(for: transaction) } } @@ -162,37 +162,37 @@ extension ClosureSDKSynchronizer: ClosureSynchronizer { limit: Int, completion: @escaping (Result<[ZcashTransaction.Overview], Error>) -> Void ) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.allConfirmedTransactions(from: transaction, limit: limit) } } public func latestHeight(completion: @escaping (Result) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.latestHeight() } } public func refreshUTXOs(address: TransparentAddress, from height: BlockHeight, completion: @escaping (Result) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.refreshUTXOs(address: address, from: height) } } public func getTransparentBalance(accountIndex: Int, completion: @escaping (Result) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.getTransparentBalance(accountIndex: accountIndex) } } public func getShieldedBalance(accountIndex: Int, completion: @escaping (Result) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.getShieldedBalance(accountIndex: accountIndex) } } public func getShieldedVerifiedBalance(accountIndex: Int, completion: @escaping (Result) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.getShieldedVerifiedBalance(accountIndex: accountIndex) } } @@ -204,41 +204,3 @@ extension ClosureSDKSynchronizer: ClosureSynchronizer { public func rewind(_ policy: RewindPolicy) -> CompletablePublisher { synchronizer.rewind(policy) } public func wipe() -> CompletablePublisher { synchronizer.wipe() } } - -extension ClosureSDKSynchronizer { - private func executeAction(_ completion: @escaping () -> Void, action: @escaping () async -> Void) { - Task { - await action() - completion() - } - } - - private func executeAction(_ completion: @escaping (R) -> Void, action: @escaping () async -> R) { - Task { - let result = await action() - completion(result) - } - } - - private func executeThrowingAction(_ completion: @escaping (Error?) -> Void, action: @escaping () async throws -> Void) { - Task { - do { - try await action() - completion(nil) - } catch { - completion(error) - } - } - } - - private func executeThrowingAction(_ completion: @escaping (Result) -> Void, action: @escaping () async throws -> R) { - Task { - do { - let result = try await action() - completion(.success(result)) - } catch { - completion(.failure(error)) - } - } - } -} diff --git a/Sources/ZcashLightClientKit/Synchronizer/CombineSDKSynchronizer.swift b/Sources/ZcashLightClientKit/Synchronizer/CombineSDKSynchronizer.swift index f725050e..ea531dce 100644 --- a/Sources/ZcashLightClientKit/Synchronizer/CombineSDKSynchronizer.swift +++ b/Sources/ZcashLightClientKit/Synchronizer/CombineSDKSynchronizer.swift @@ -36,37 +36,37 @@ extension CombineSDKSynchronizer: CombineSynchronizer { viewingKeys: [UnifiedFullViewingKey], walletBirthday: BlockHeight ) -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { return try await self.synchronizer.prepare(with: seed, viewingKeys: viewingKeys, walletBirthday: walletBirthday) } } public func start(retry: Bool) -> CompletablePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.start(retry: retry) } } public func stop() -> CompletablePublisher { - return executeAction() { + AsyncToCombineGateway.executeAction() { await self.synchronizer.stop() } } public func getSaplingAddress(accountIndex: Int) -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.getSaplingAddress(accountIndex: accountIndex) } } public func getUnifiedAddress(accountIndex: Int) -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.getUnifiedAddress(accountIndex: accountIndex) } } public func getTransparentAddress(accountIndex: Int) -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.getTransparentAddress(accountIndex: accountIndex) } } @@ -77,7 +77,7 @@ extension CombineSDKSynchronizer: CombineSynchronizer { toAddress: Recipient, memo: Memo? ) -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.sendToAddress(spendingKey: spendingKey, zatoshi: zatoshi, toAddress: toAddress, memo: memo) } } @@ -87,37 +87,37 @@ extension CombineSDKSynchronizer: CombineSynchronizer { memo: Memo, shieldingThreshold: Zatoshi ) -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.shieldFunds(spendingKey: spendingKey, memo: memo, shieldingThreshold: shieldingThreshold) } } public func cancelSpend(transaction: PendingTransactionEntity) -> SinglePublisher { - executeAction() { + AsyncToCombineGateway.executeAction() { await self.synchronizer.cancelSpend(transaction: transaction) } } public var pendingTransactions: SinglePublisher<[PendingTransactionEntity], Never> { - executeAction() { + AsyncToCombineGateway.executeAction() { await self.synchronizer.pendingTransactions } } public var clearedTransactions: SinglePublisher<[ZcashTransaction.Overview], Never> { - executeAction() { + AsyncToCombineGateway.executeAction() { await self.synchronizer.clearedTransactions } } public var sentTransactions: SinglePublisher<[ZcashTransaction.Sent], Never> { - executeAction() { + AsyncToCombineGateway.executeAction() { await self.synchronizer.sentTransactions } } public var receivedTransactions: SinglePublisher<[ZcashTransaction.Received], Never> { - executeAction() { + AsyncToCombineGateway.executeAction() { await self.synchronizer.receivedTransactions } } @@ -125,67 +125,67 @@ extension CombineSDKSynchronizer: CombineSynchronizer { public func paginatedTransactions(of kind: TransactionKind) -> PaginatedTransactionRepository { synchronizer.paginatedTransactions(of: kind) } public func getMemos(for transaction: ZcashTransaction.Overview) -> SinglePublisher<[Memo], Error> { - executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.getMemos(for: transaction) } } public func getMemos(for receivedTransaction: ZcashTransaction.Received) -> SinglePublisher<[Memo], Error> { - executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.getMemos(for: receivedTransaction) } } public func getMemos(for sentTransaction: ZcashTransaction.Sent) -> SinglePublisher<[Memo], Error> { - executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.getMemos(for: sentTransaction) } } public func getRecipients(for transaction: ZcashTransaction.Overview) -> SinglePublisher<[TransactionRecipient], Never> { - executeAction() { + AsyncToCombineGateway.executeAction() { await self.synchronizer.getRecipients(for: transaction) } } public func getRecipients(for transaction: ZcashTransaction.Sent) -> SinglePublisher<[TransactionRecipient], Never> { - executeAction() { + AsyncToCombineGateway.executeAction() { await self.synchronizer.getRecipients(for: transaction) } } public func allConfirmedTransactions(from transaction: ZcashTransaction.Overview, limit: Int) -> SinglePublisher<[ZcashTransaction.Overview], Error> { - executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.allConfirmedTransactions(from: transaction, limit: limit) } } public func latestHeight() -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.latestHeight() } } public func refreshUTXOs(address: TransparentAddress, from height: BlockHeight) -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.refreshUTXOs(address: address, from: height) } } public func getTransparentBalance(accountIndex: Int) -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.getTransparentBalance(accountIndex: accountIndex) } } public func getShieldedBalance(accountIndex: Int = 0) -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await synchronizer.getShieldedBalance(accountIndex: accountIndex) } } public func getShieldedVerifiedBalance(accountIndex: Int = 0) -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await synchronizer.getShieldedVerifiedBalance(accountIndex: accountIndex) } } @@ -193,53 +193,3 @@ extension CombineSDKSynchronizer: CombineSynchronizer { public func rewind(_ policy: RewindPolicy) -> CompletablePublisher { synchronizer.rewind(policy) } public func wipe() -> CompletablePublisher { synchronizer.wipe() } } - -extension CombineSDKSynchronizer { - private func executeAction(action: @escaping () async -> Void) -> CompletablePublisher { - let subject = PassthroughSubject() - Task { - await action() - subject.send(completion: .finished) - } - return subject.eraseToAnyPublisher() - } - - private func executeAction(action: @escaping () async -> R) -> SinglePublisher { - let subject = PassthroughSubject() - Task { - let result = await action() - subject.send(result) - subject.send(completion: .finished) - } - return subject.eraseToAnyPublisher() - } - - private func executeThrowingAction(action: @escaping () async throws -> Void) -> CompletablePublisher { - let subject = PassthroughSubject() - Task { - do { - try await action() - subject.send(completion: .finished) - } catch { - subject.send(completion: .failure(error)) - } - } - - return subject.eraseToAnyPublisher() - } - - private func executeThrowingAction(action: @escaping () async throws -> R) -> SinglePublisher { - let subject = PassthroughSubject() - Task { - do { - let result = try await action() - subject.send(result) - subject.send(completion: .finished) - } catch { - subject.send(completion: .failure(error)) - } - } - - return subject.eraseToAnyPublisher() - } -} diff --git a/Sources/ZcashLightClientKit/Synchronizer/SDKSynchronizer.swift b/Sources/ZcashLightClientKit/Synchronizer/SDKSynchronizer.swift index a721f7b7..6803a693 100644 --- a/Sources/ZcashLightClientKit/Synchronizer/SDKSynchronizer.swift +++ b/Sources/ZcashLightClientKit/Synchronizer/SDKSynchronizer.swift @@ -458,21 +458,13 @@ public class SDKSynchronizer: Synchronizer { } public func getShieldedBalance(accountIndex: Int = 0) async throws -> Zatoshi { - let balance = try await initializer.rustBackend.getBalance( - dbData: initializer.dataDbURL, - account: Int32(accountIndex), - networkType: network.networkType - ) + let balance = try await initializer.rustBackend.getBalance(account: Int32(accountIndex)) return Zatoshi(balance) } public func getShieldedVerifiedBalance(accountIndex: Int = 0) async throws -> Zatoshi { - let balance = try await initializer.rustBackend.getVerifiedBalance( - dbData: initializer.dataDbURL, - account: Int32(accountIndex), - networkType: network.networkType - ) + let balance = try await initializer.rustBackend.getVerifiedBalance(account: Int32(accountIndex)) return Zatoshi(balance) } @@ -617,7 +609,6 @@ public class SDKSynchronizer: Synchronizer { newState = await snapshotState(status: newStatus) } else { - newState = await SynchronizerState( syncSessionID: syncSession.value, shieldedBalance: latestState.shieldedBalance, diff --git a/Sources/ZcashLightClientKit/Tool/DerivationTool.swift b/Sources/ZcashLightClientKit/Tool/DerivationTool.swift index 6bfa8b7b..5e4d501d 100644 --- a/Sources/ZcashLightClientKit/Tool/DerivationTool.swift +++ b/Sources/ZcashLightClientKit/Tool/DerivationTool.swift @@ -5,17 +5,14 @@ // Created by Francisco Gindre on 10/8/20. // +import Combine import Foundation public protocol KeyValidation { func isValidUnifiedFullViewingKey(_ ufvk: String) -> Bool - func isValidTransparentAddress(_ tAddress: String) -> Bool - func isValidSaplingAddress(_ zAddress: String) -> Bool - func isValidSaplingExtendedSpendingKey(_ extsk: String) -> Bool - func isValidUnifiedAddress(_ unifiedAddress: String) -> Bool } @@ -25,26 +22,39 @@ public protocol KeyDeriving { /// - Parameter accountNumber: `Int` with the account number /// - Throws `.unableToDerive` if there's a problem deriving this key /// - Returns a `UnifiedSpendingKey` - func deriveUnifiedSpendingKey(seed: [UInt8], accountIndex: Int) throws -> UnifiedSpendingKey + func deriveUnifiedSpendingKey(seed: [UInt8], accountIndex: Int) async throws -> UnifiedSpendingKey + func deriveUnifiedSpendingKey(seed: [UInt8], accountIndex: Int, completion: @escaping (Result) -> Void) + func deriveUnifiedSpendingKey(seed: [UInt8], accountIndex: Int) -> SinglePublisher + + /// Given a spending key, return the associated viewing key. + /// - Parameter spendingKey: the `UnifiedSpendingKey` from which to derive the `UnifiedFullViewingKey` from. + /// - Returns: the viewing key that corresponds to the spending key. + func deriveUnifiedFullViewingKey(from spendingKey: UnifiedSpendingKey) async throws -> UnifiedFullViewingKey + func deriveUnifiedFullViewingKey(from spendingKey: UnifiedSpendingKey, completion: @escaping (Result) -> Void) + func deriveUnifiedFullViewingKey(from spendingKey: UnifiedSpendingKey) -> SinglePublisher /// Extracts the `SaplingAddress` from the given `UnifiedAddress` /// - Parameter address: the `UnifiedAddress` /// - Throws `KeyDerivationErrors.receiverNotFound` if the receiver is not present - static func saplingReceiver(from unifiedAddress: UnifiedAddress) throws -> SaplingAddress + func saplingReceiver(from unifiedAddress: UnifiedAddress) throws -> SaplingAddress /// Extracts the `TransparentAddress` from the given `UnifiedAddress` /// - Parameter address: the `UnifiedAddress` /// - Throws `KeyDerivationErrors.receiverNotFound` if the receiver is not present - static func transparentReceiver(from unifiedAddress: UnifiedAddress) throws -> TransparentAddress + func transparentReceiver(from unifiedAddress: UnifiedAddress) throws -> TransparentAddress /// Extracts the `UnifiedAddress.ReceiverTypecodes` from the given `UnifiedAddress` /// - Parameter address: the `UnifiedAddress` /// - Throws - static func receiverTypecodesFromUnifiedAddress(_ address: UnifiedAddress) throws -> [UnifiedAddress.ReceiverTypecodes] + func receiverTypecodesFromUnifiedAddress(_ address: UnifiedAddress) throws -> [UnifiedAddress.ReceiverTypecodes] + + static func getAddressMetadata(_ addr: String) -> AddressMetadata? } public enum KeyDerivationErrors: Error { case derivationError(underlyingError: Error) + // When something happens that is not related to derivation itself happens. For example if self is nil in closure. + case genericOtherError case unableToDerive case invalidInput case invalidUnifiedAddress @@ -52,31 +62,43 @@ public enum KeyDerivationErrors: Error { } public class DerivationTool: KeyDeriving { - static var rustwelding: ZcashRustBackendWelding.Type = ZcashRustBackend.self + let backend: ZcashKeyDerivationBackendWelding - var networkType: NetworkType - - public init(networkType: NetworkType) { - self.networkType = networkType + init(networkType: NetworkType) { + self.backend = ZcashKeyDerivationBackend(networkType: networkType) } - public static func saplingReceiver(from unifiedAddress: UnifiedAddress) throws -> SaplingAddress { - try rustwelding.getSaplingReceiver(for: unifiedAddress) + public func saplingReceiver(from unifiedAddress: UnifiedAddress) throws -> SaplingAddress { + try backend.getSaplingReceiver(for: unifiedAddress) } - public static func transparentReceiver(from unifiedAddress: UnifiedAddress) throws -> TransparentAddress { - try rustwelding.getTransparentReceiver(for: unifiedAddress) + public func transparentReceiver(from unifiedAddress: UnifiedAddress) throws -> TransparentAddress { + try backend.getTransparentReceiver(for: unifiedAddress) } public static func getAddressMetadata(_ addr: String) -> AddressMetadata? { - rustwelding.getAddressMetadata(addr) + ZcashKeyDerivationBackend.getAddressMetadata(addr) } /// Given a spending key, return the associated viewing key. /// - Parameter spendingKey: the `UnifiedSpendingKey` from which to derive the `UnifiedFullViewingKey` from. /// - Returns: the viewing key that corresponds to the spending key. - public func deriveUnifiedFullViewingKey(from spendingKey: UnifiedSpendingKey) throws -> UnifiedFullViewingKey { - try DerivationTool.rustwelding.deriveUnifiedFullViewingKey(from: spendingKey, networkType: self.networkType) + public func deriveUnifiedFullViewingKey(from spendingKey: UnifiedSpendingKey) async throws -> UnifiedFullViewingKey { + try await backend.deriveUnifiedFullViewingKey(from: spendingKey) + } + + public func deriveUnifiedFullViewingKey(from spendingKey: UnifiedSpendingKey, completion: @escaping (Result) -> Void) { + AsyncToClosureGateway.executeThrowingAction(completion) { [weak self] in + guard let self else { throw KeyDerivationErrors.genericOtherError } + return try await self.deriveUnifiedFullViewingKey(from: spendingKey) + } + } + + public func deriveUnifiedFullViewingKey(from spendingKey: UnifiedSpendingKey) -> SinglePublisher { + AsyncToCombineGateway.executeThrowingAction() { [weak self] in + guard let self else { throw KeyDerivationErrors.genericOtherError } + return try await self.deriveUnifiedFullViewingKey(from: spendingKey) + } } /// Given a seed and a number of accounts, return the associated spending keys. @@ -84,24 +106,34 @@ public class DerivationTool: KeyDeriving { /// - Parameter numberOfAccounts: the number of accounts to use. Multiple accounts are not fully /// supported so the default value of 1 is recommended. /// - Returns: the spending keys that correspond to the seed, formatted as Strings. - public func deriveUnifiedSpendingKey(seed: [UInt8], accountIndex: Int) throws -> UnifiedSpendingKey { + public func deriveUnifiedSpendingKey(seed: [UInt8], accountIndex: Int) async throws -> UnifiedSpendingKey { guard accountIndex >= 0, let accountIndex = Int32(exactly: accountIndex) else { throw KeyDerivationErrors.invalidInput } do { - return try DerivationTool.rustwelding.deriveUnifiedSpendingKey( - from: seed, - accountIndex: accountIndex, - networkType: self.networkType - ) + return try await backend.deriveUnifiedSpendingKey(from: seed, accountIndex: accountIndex) } catch { throw KeyDerivationErrors.unableToDerive } } - public static func receiverTypecodesFromUnifiedAddress(_ address: UnifiedAddress) throws -> [UnifiedAddress.ReceiverTypecodes] { + public func deriveUnifiedSpendingKey(seed: [UInt8], accountIndex: Int, completion: @escaping (Result) -> Void) { + AsyncToClosureGateway.executeThrowingAction(completion) { [weak self] in + guard let self else { throw KeyDerivationErrors.genericOtherError } + return try await self.deriveUnifiedSpendingKey(seed: seed, accountIndex: accountIndex) + } + } + + public func deriveUnifiedSpendingKey(seed: [UInt8], accountIndex: Int) -> SinglePublisher { + AsyncToCombineGateway.executeThrowingAction() { [weak self] in + guard let self else { throw KeyDerivationErrors.genericOtherError } + return try await self.deriveUnifiedSpendingKey(seed: seed, accountIndex: accountIndex) + } + } + + public func receiverTypecodesFromUnifiedAddress(_ address: UnifiedAddress) throws -> [UnifiedAddress.ReceiverTypecodes] { do { - return try DerivationTool.rustwelding.receiverTypecodesOnUnifiedAddress(address.stringEncoded) + return try backend.receiverTypecodesOnUnifiedAddress(address.stringEncoded) .map({ UnifiedAddress.ReceiverTypecodes(typecode: $0) }) } catch { throw KeyDerivationErrors.invalidUnifiedAddress @@ -121,23 +153,23 @@ public struct AddressMetadata { extension DerivationTool: KeyValidation { public func isValidUnifiedFullViewingKey(_ ufvk: String) -> Bool { - DerivationTool.rustwelding.isValidUnifiedFullViewingKey(ufvk, networkType: networkType) + backend.isValidUnifiedFullViewingKey(ufvk) } public func isValidUnifiedAddress(_ unifiedAddress: String) -> Bool { - DerivationTool.rustwelding.isValidUnifiedAddress(unifiedAddress, networkType: networkType) + backend.isValidUnifiedAddress(unifiedAddress) } public func isValidTransparentAddress(_ tAddress: String) -> Bool { - DerivationTool.rustwelding.isValidTransparentAddress(tAddress, networkType: networkType) + backend.isValidTransparentAddress(tAddress) } public func isValidSaplingAddress(_ zAddress: String) -> Bool { - DerivationTool.rustwelding.isValidSaplingAddress(zAddress, networkType: networkType) + backend.isValidSaplingAddress(zAddress) } public func isValidSaplingExtendedSpendingKey(_ extsk: String) -> Bool { - DerivationTool.rustwelding.isValidSaplingExtendedSpendingKey(extsk, networkType: networkType) + backend.isValidSaplingExtendedSpendingKey(extsk) } } @@ -166,8 +198,9 @@ extension UnifiedAddress { /// already validated by another function. only for internal use. Unless you are /// constructing an address from a primitive function of the FFI, you probably /// shouldn't be using this.. - init(validatedEncoding: String) { + init(validatedEncoding: String, networkType: NetworkType) { self.encoding = validatedEncoding + self.networkType = networkType } } @@ -206,22 +239,18 @@ public extension UnifiedSpendingKey { func map(_ transform: (UnifiedSpendingKey) throws -> T) rethrows -> T { try transform(self) } - - func deriveFullViewingKey() throws -> UnifiedFullViewingKey { - try DerivationTool(networkType: self.network).deriveUnifiedFullViewingKey(from: self) - } } public extension UnifiedAddress { /// Extracts the sapling receiver from this UA if available /// - Returns: an `Optional` func saplingReceiver() throws -> SaplingAddress { - try DerivationTool.saplingReceiver(from: self) + try DerivationTool(networkType: networkType).saplingReceiver(from: self) } /// Extracts the transparent receiver from this UA if available /// - Returns: an `Optional` func transparentReceiver() throws -> TransparentAddress { - try DerivationTool.transparentReceiver(from: self) + try DerivationTool(networkType: networkType).transparentReceiver(from: self) } } diff --git a/Sources/ZcashLightClientKit/Transaction/WalletTransactionEncoder.swift b/Sources/ZcashLightClientKit/Transaction/WalletTransactionEncoder.swift index d1ed988d..5da5d365 100644 --- a/Sources/ZcashLightClientKit/Transaction/WalletTransactionEncoder.swift +++ b/Sources/ZcashLightClientKit/Transaction/WalletTransactionEncoder.swift @@ -8,7 +8,7 @@ import Foundation class WalletTransactionEncoder: TransactionEncoder { - var rustBackend: ZcashRustBackendWelding.Type + var rustBackend: ZcashRustBackendWelding var repository: TransactionRepository let logger: Logger @@ -19,7 +19,7 @@ class WalletTransactionEncoder: TransactionEncoder { private var networkType: NetworkType init( - rust: ZcashRustBackendWelding.Type, + rustBackend: ZcashRustBackendWelding, dataDb: URL, fsBlockDbRoot: URL, repository: TransactionRepository, @@ -28,7 +28,7 @@ class WalletTransactionEncoder: TransactionEncoder { networkType: NetworkType, logger: Logger ) { - self.rustBackend = rust + self.rustBackend = rustBackend self.dataDbURL = dataDb self.fsBlockDbRoot = fsBlockDbRoot self.repository = repository @@ -40,7 +40,7 @@ class WalletTransactionEncoder: TransactionEncoder { convenience init(initializer: Initializer) { self.init( - rust: initializer.rustBackend, + rustBackend: initializer.rustBackend, dataDb: initializer.dataDbURL, fsBlockDbRoot: initializer.fsBlockDbRoot, repository: initializer.transactionRepository, @@ -84,22 +84,14 @@ class WalletTransactionEncoder: TransactionEncoder { guard ensureParams(spend: self.spendParamsURL, output: self.outputParamsURL) else { throw TransactionEncoderError.missingParams } - - let txId = await rustBackend.createToAddress( - dbData: self.dataDbURL, + + let txId = try await rustBackend.createToAddress( usk: spendingKey, to: address, value: zatoshi.amount, - memo: memoBytes, - spendParamsPath: self.spendParamsURL.path, - outputParamsPath: self.outputParamsURL.path, - networkType: networkType + memo: memoBytes ) - - guard txId > 0 else { - throw rustBackend.lastError() ?? RustWeldingError.genericError(message: "create spend failed") - } - + return Int(txId) } @@ -134,20 +126,12 @@ class WalletTransactionEncoder: TransactionEncoder { throw TransactionEncoderError.missingParams } - let txId = await rustBackend.shieldFunds( - dbData: self.dataDbURL, + let txId = try await rustBackend.shieldFunds( usk: spendingKey, memo: memo, - shieldingThreshold: shieldingThreshold, - spendParamsPath: self.spendParamsURL.path, - outputParamsPath: self.outputParamsURL.path, - networkType: networkType + shieldingThreshold: shieldingThreshold ) - - guard txId > 0 else { - throw rustBackend.lastError() ?? RustWeldingError.genericError(message: "create spend failed") - } - + return Int(txId) } diff --git a/Sources/ZcashLightClientKit/Utils/AsyncToClosureGateway.swift b/Sources/ZcashLightClientKit/Utils/AsyncToClosureGateway.swift new file mode 100644 index 00000000..78819e0f --- /dev/null +++ b/Sources/ZcashLightClientKit/Utils/AsyncToClosureGateway.swift @@ -0,0 +1,46 @@ +// +// AsyncToClosureGateway.swift +// +// +// Created by Michal Fousek on 03.04.2023. +// + +import Foundation + +enum AsyncToClosureGateway { + static func executeAction(_ completion: @escaping () -> Void, action: @escaping () async -> Void) { + Task { + await action() + completion() + } + } + + static func executeAction(_ completion: @escaping (R) -> Void, action: @escaping () async -> R) { + Task { + let result = await action() + completion(result) + } + } + + static func executeThrowingAction(_ completion: @escaping (Error?) -> Void, action: @escaping () async throws -> Void) { + Task { + do { + try await action() + completion(nil) + } catch { + completion(error) + } + } + } + + static func executeThrowingAction(_ completion: @escaping (Result) -> Void, action: @escaping () async throws -> R) { + Task { + do { + let result = try await action() + completion(.success(result)) + } catch { + completion(.failure(error)) + } + } + } +} diff --git a/Sources/ZcashLightClientKit/Utils/AsyncToCombineGateway.swift b/Sources/ZcashLightClientKit/Utils/AsyncToCombineGateway.swift new file mode 100644 index 00000000..33895634 --- /dev/null +++ b/Sources/ZcashLightClientKit/Utils/AsyncToCombineGateway.swift @@ -0,0 +1,59 @@ +// +// AsyncToCombineGateway.swift +// +// +// Created by Michal Fousek on 03.04.2023. +// + +import Combine +import Foundation + +enum AsyncToCombineGateway { + static func executeAction(action: @escaping () async -> Void) -> CompletablePublisher { + let subject = PassthroughSubject() + Task { + await action() + subject.send(completion: .finished) + } + return subject.eraseToAnyPublisher() + } + + static func executeAction(action: @escaping () async -> R) -> SinglePublisher { + let subject = PassthroughSubject() + Task { + let result = await action() + subject.send(result) + subject.send(completion: .finished) + } + return subject.eraseToAnyPublisher() + } + + static func executeThrowingAction(action: @escaping () async throws -> Void) -> CompletablePublisher { + let subject = PassthroughSubject() + Task { + do { + try await action() + subject.send(completion: .finished) + } catch { + subject.send(completion: .failure(error)) + } + } + + return subject.eraseToAnyPublisher() + } + + static func executeThrowingAction(action: @escaping () async throws -> R) -> SinglePublisher { + let subject = PassthroughSubject() + Task { + do { + let result = try await action() + subject.send(result) + subject.send(completion: .finished) + } catch { + subject.send(completion: .failure(error)) + } + } + + return subject.eraseToAnyPublisher() + } +} diff --git a/Sources/ZcashLightClientKit/Utils/SpecificCombineTypes.swift b/Sources/ZcashLightClientKit/Utils/SpecificCombineTypes.swift new file mode 100644 index 00000000..9f4a6cf8 --- /dev/null +++ b/Sources/ZcashLightClientKit/Utils/SpecificCombineTypes.swift @@ -0,0 +1,16 @@ +// +// SpecificCombineTypes.swift +// +// +// Created by Michal Fousek on 03.04.2023. +// + +import Combine +import Foundation + +/* These aliases are here to just make the API easier to read. */ + +// Publisher which emitts completed or error. No value is emitted. +public typealias CompletablePublisher = AnyPublisher +// Publisher that either emits one value and then finishes or it emits error. +public typealias SinglePublisher = AnyPublisher diff --git a/Sources/ZcashLightClientKit/Utils/SyncSessionIDGenerator.swift b/Sources/ZcashLightClientKit/Utils/SyncSessionIDGenerator.swift index 99a85127..16fde23a 100644 --- a/Sources/ZcashLightClientKit/Utils/SyncSessionIDGenerator.swift +++ b/Sources/ZcashLightClientKit/Utils/SyncSessionIDGenerator.swift @@ -7,7 +7,6 @@ import Foundation - protocol SyncSessionIDGenerator { func nextID() -> UUID } diff --git a/Tests/DarksideTests/BlockDownloaderTests.swift b/Tests/DarksideTests/BlockDownloaderTests.swift index 5157ac87..bf720fb7 100644 --- a/Tests/DarksideTests/BlockDownloaderTests.swift +++ b/Tests/DarksideTests/BlockDownloaderTests.swift @@ -13,10 +13,6 @@ import XCTest class BlockDownloaderTests: XCTestCase { let branchID = "2bb40e60" let chainName = "main" - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) let testFileManager = FileManager() var darksideWalletService: DarksideWalletService! @@ -24,16 +20,25 @@ class BlockDownloaderTests: XCTestCase { var service: LightWalletService! var storage: CompactBlockRepository! var network = DarksideWalletDNetwork() + var rustBackend: ZcashRustBackendWelding! + var testTempDirectory: URL! override func setUp() async throws { try await super.setUp() + testTempDirectory = Environment.uniqueTestTempDirectory + service = LightWalletServiceFactory(endpoint: LightWalletEndpointBuilder.default).make() + rustBackend = ZcashRustBackend.makeForTests( + fsBlockDbRoot: testTempDirectory, + networkType: network.networkType + ) + storage = FSCompactBlockRepository( fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: ZcashRustBackend.self, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -58,6 +63,8 @@ class BlockDownloaderTests: XCTestCase { service = nil storage = nil downloader = nil + rustBackend = nil + testTempDirectory = nil } func testSmallDownload() async { diff --git a/Tests/DarksideTests/RewindRescanTests.swift b/Tests/DarksideTests/RewindRescanTests.swift index 59d05b85..b3920655 100644 --- a/Tests/DarksideTests/RewindRescanTests.swift +++ b/Tests/DarksideTests/RewindRescanTests.swift @@ -173,11 +173,13 @@ class RewindRescanTests: XCTestCase { // rewind to birthday let targetHeight: BlockHeight = newChaintTip - 8000 - let rewindHeight = await ZcashRustBackend.getNearestRewindHeight( - dbData: coordinator.databases.dataDB, - height: Int32(targetHeight), - networkType: network.networkType - ) + + do { + _ = try await coordinator.synchronizer.initializer.rustBackend.getNearestRewindHeight(height: Int32(targetHeight)) + } catch { + XCTFail("get nearest height failed error: \(error)") + return + } let rewindExpectation = XCTestExpectation(description: "RewindExpectation") @@ -202,11 +204,6 @@ class RewindRescanTests: XCTestCase { wait(for: [rewindExpectation], timeout: 2) - guard rewindHeight > 0 else { - XCTFail("get nearest height failed error: \(ZcashRustBackend.getLastError() ?? "null")") - return - } - // check that the balance is cleared var expectedVerifiedBalance = try await coordinator.synchronizer.getShieldedVerifiedBalance() XCTAssertEqual(initialVerifiedBalance, expectedVerifiedBalance) diff --git a/Tests/DarksideTests/SynchronizerDarksideTests.swift b/Tests/DarksideTests/SynchronizerDarksideTests.swift index 253963f7..8751fe75 100644 --- a/Tests/DarksideTests/SynchronizerDarksideTests.swift +++ b/Tests/DarksideTests/SynchronizerDarksideTests.swift @@ -27,6 +27,7 @@ class SynchronizerDarksideTests: XCTestCase { var foundTransactions: [ZcashTransaction.Overview] = [] var cancellables: [AnyCancellable] = [] var idGenerator: MockSyncSessionIDGenerator! + override func setUp() async throws { try await super.setUp() idGenerator = MockSyncSessionIDGenerator(ids: [.deadbeef]) @@ -78,7 +79,6 @@ class SynchronizerDarksideTests: XCTestCase { } func testFoundManyTransactions() async throws { - self.idGenerator.ids = [.deadbeef, .beefbeef, .beefdead] coordinator.synchronizer.eventStream .map { event in @@ -143,7 +143,7 @@ class SynchronizerDarksideTests: XCTestCase { XCTAssertEqual(self.foundTransactions.count, 2) } - func testLastStates() async throws { + func sdfstestLastStates() async throws { self.idGenerator.ids = [.deadbeef] var cancellables: [AnyCancellable] = [] @@ -474,7 +474,6 @@ class SynchronizerDarksideTests: XCTestCase { ] XCTAssertEqual(states, secondBatchOfExpectedStates) - } func testSyncAfterWipeWorks() async throws { diff --git a/Tests/DarksideTests/TransactionEnhancementTests.swift b/Tests/DarksideTests/TransactionEnhancementTests.swift index e9b463ef..cdc99c4b 100644 --- a/Tests/DarksideTests/TransactionEnhancementTests.swift +++ b/Tests/DarksideTests/TransactionEnhancementTests.swift @@ -19,16 +19,14 @@ class TransactionEnhancementTests: XCTestCase { let network = DarksideWalletDNetwork() let branchID = "2bb40e60" let chainName = "main" - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) + var testTempDirectory: URL! let testFileManager = FileManager() var initializer: Initializer! var processorConfig: CompactBlockProcessor.Configuration! var processor: CompactBlockProcessor! + var rustBackend: ZcashRustBackendWelding! var darksideWalletService: DarksideWalletService! var downloader: BlockDownloaderServiceImpl! var syncStartedExpect: XCTestExpectation! @@ -42,8 +40,8 @@ class TransactionEnhancementTests: XCTestCase { override func setUp() async throws { try await super.setUp() - - try self.testFileManager.createDirectory(at: self.testTempDirectory, withIntermediateDirectories: false) + testTempDirectory = Environment.uniqueTestTempDirectory + try self.testFileManager.createDirectory(at: testTempDirectory, withIntermediateDirectories: false) await InternalSyncProgress( alias: .default, @@ -63,7 +61,6 @@ class TransactionEnhancementTests: XCTestCase { waitExpectation = XCTestExpectation(description: "\(self.description) waitExpectation") - let rustBackend = ZcashRustBackend.self let birthday = Checkpoint.birthday(with: walletBirthday, network: network) let pathProvider = DefaultResourceProvider(network: network) @@ -78,27 +75,25 @@ class TransactionEnhancementTests: XCTestCase { network: network ) + rustBackend = ZcashRustBackend.makeForTests( + dbData: processorConfig.dataDb, + fsBlockDbRoot: testTempDirectory, + networkType: network.networkType + ) + try? FileManager.default.removeItem(at: processorConfig.fsBlockCacheRoot) try? FileManager.default.removeItem(at: processorConfig.dataDb) - let dbInit = try await rustBackend.initDataDb(dbData: processorConfig.dataDb, seed: nil, networkType: network.networkType) - - let ufvks = [ - try DerivationTool(networkType: network.networkType) - .deriveUnifiedSpendingKey(seed: Environment.seedBytes, accountIndex: 0) - .map { - try DerivationTool(networkType: network.networkType) - .deriveUnifiedFullViewingKey(from: $0) - } - ] + let dbInit = try await rustBackend.initDataDb(seed: nil) + + let derivationTool = DerivationTool(networkType: network.networkType) + let spendingKey = try await derivationTool.deriveUnifiedSpendingKey(seed: Environment.seedBytes, accountIndex: 0) + let viewingKey = try await derivationTool.deriveUnifiedFullViewingKey(from: spendingKey) + do { - try await rustBackend.initAccountsTable( - dbData: processorConfig.dataDb, - ufvks: ufvks, - networkType: network.networkType - ) + try await rustBackend.initAccountsTable(ufvks: [viewingKey]) } catch { - XCTFail("Failed to init accounts table error: \(String(describing: rustBackend.getLastError()))") + XCTFail("Failed to init accounts table error: \(error)") return } @@ -108,12 +103,10 @@ class TransactionEnhancementTests: XCTestCase { } _ = try await rustBackend.initBlocksTable( - dbData: processorConfig.dataDb, height: Int32(birthday.height), hash: birthday.hash, time: birthday.time, - saplingTree: birthday.saplingTree, - networkType: network.networkType + saplingTree: birthday.saplingTree ) let service = DarksideWalletService() @@ -136,7 +129,7 @@ class TransactionEnhancementTests: XCTestCase { processor = CompactBlockProcessor( service: service, storage: storage, - backend: rustBackend, + rustBackend: rustBackend, config: processorConfig, metrics: SDKMetrics(), logger: logger @@ -163,6 +156,7 @@ class TransactionEnhancementTests: XCTestCase { processor = nil darksideWalletService = nil downloader = nil + testTempDirectory = nil } private func startProcessing() async throws { diff --git a/Tests/NetworkTests/BlockScanTests.swift b/Tests/NetworkTests/BlockScanTests.swift index 57fe64f6..75323999 100644 --- a/Tests/NetworkTests/BlockScanTests.swift +++ b/Tests/NetworkTests/BlockScanTests.swift @@ -15,8 +15,6 @@ import SQLite class BlockScanTests: XCTestCase { var cancelables: [AnyCancellable] = [] - let rustWelding = ZcashRustBackend.self - var dataDbURL: URL! var spendParamsURL: URL! var outputParamsURL: URL! @@ -27,25 +25,30 @@ class BlockScanTests: XCTestCase { with: 1386000, network: ZcashNetworkBuilder.network(for: .testnet) ) + + var rustBackend: ZcashRustBackendWelding! var network = ZcashNetworkBuilder.network(for: .testnet) var blockRepository: BlockRepository! - - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) + var testTempDirectory: URL! let testFileManager = FileManager() override func setUpWithError() throws { // Put setup code here. This method is called before the invocation of each test method in the class. try super.setUpWithError() - self.dataDbURL = try! __dataDbURL() - self.spendParamsURL = try! __spendParamsURL() - self.outputParamsURL = try! __outputParamsURL() + dataDbURL = try! __dataDbURL() + spendParamsURL = try! __spendParamsURL() + outputParamsURL = try! __outputParamsURL() + testTempDirectory = Environment.uniqueTestTempDirectory - try self.testFileManager.createDirectory(at: self.testTempDirectory, withIntermediateDirectories: false) + try self.testFileManager.createDirectory(at: testTempDirectory, withIntermediateDirectories: false) + + rustBackend = ZcashRustBackend.makeForTests( + dbData: dataDbURL, + fsBlockDbRoot: testTempDirectory, + networkType: network.networkType + ) deleteDBs() } @@ -63,25 +66,23 @@ class BlockScanTests: XCTestCase { try? testFileManager.removeItem(at: testTempDirectory) cancelables = [] blockRepository = nil + testTempDirectory = nil } func testSingleDownloadAndScan() async throws { logger = OSLogger(logLevel: .debug) - _ = try await rustWelding.initDataDb(dbData: dataDbURL, seed: nil, networkType: network.networkType) + _ = try await rustBackend.initDataDb(seed: nil) let endpoint = LightWalletEndpoint(address: "lightwalletd.testnet.electriccoin.co", port: 9067) let service = LightWalletServiceFactory(endpoint: endpoint).make() let blockCount = 100 let range = network.constants.saplingActivationHeight ... network.constants.saplingActivationHeight + blockCount - let fsDbRootURL = self.testTempDirectory - - let rustBackend = ZcashRustBackend.self let fsBlockRepository = FSCompactBlockRepository( - fsBlockDbRoot: fsDbRootURL, + fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( - fsBlockDbRoot: fsDbRootURL, + fsBlockDbRoot: testTempDirectory, rustBackend: rustBackend, logger: logger ), @@ -94,7 +95,7 @@ class BlockScanTests: XCTestCase { let processorConfig = CompactBlockProcessor.Configuration( alias: .default, - fsBlockCacheRoot: fsDbRootURL, + fsBlockCacheRoot: testTempDirectory, dataDb: dataDbURL, spendParamsURL: spendParamsURL, outputParamsURL: outputParamsURL, @@ -106,7 +107,7 @@ class BlockScanTests: XCTestCase { let compactBlockProcessor = CompactBlockProcessor( service: service, storage: fsBlockRepository, - backend: rustBackend, + rustBackend: rustBackend, config: processorConfig, metrics: SDKMetrics(), logger: logger @@ -139,45 +140,36 @@ class BlockScanTests: XCTestCase { let metrics = SDKMetrics() metrics.enableMetrics() - guard try await self.rustWelding.initDataDb(dbData: dataDbURL, seed: nil, networkType: network.networkType) == .success else { + guard try await rustBackend.initDataDb(seed: nil) == .success else { XCTFail("Seed should not be required for this test") return } let derivationTool = DerivationTool(networkType: .testnet) - let ufvk = try derivationTool - .deriveUnifiedSpendingKey(seed: Array(seed.utf8), accountIndex: 0) - .map { try derivationTool.deriveUnifiedFullViewingKey(from: $0) } + let spendingKey = try await derivationTool.deriveUnifiedSpendingKey(seed: Array(seed.utf8), accountIndex: 0) + let viewingKey = try await derivationTool.deriveUnifiedFullViewingKey(from: spendingKey) do { - try await self.rustWelding.initAccountsTable( - dbData: self.dataDbURL, - ufvks: [ufvk], - networkType: network.networkType - ) + try await rustBackend.initAccountsTable(ufvks: [viewingKey]) } catch { - XCTFail("failed to init account table. error: \(self.rustWelding.getLastError() ?? "no error found")") + XCTFail("failed to init account table. error: \(error)") return } - try await self.rustWelding.initBlocksTable( - dbData: dataDbURL, + try await rustBackend.initBlocksTable( height: Int32(walletBirthDay.height), hash: walletBirthDay.hash, time: walletBirthDay.time, - saplingTree: walletBirthDay.saplingTree, - networkType: network.networkType + saplingTree: walletBirthDay.saplingTree ) let service = LightWalletServiceFactory(endpoint: LightWalletEndpointBuilder.eccTestnet).make() - let fsDbRootURL = self.testTempDirectory - let fsBlockRepository = FSCompactBlockRepository( - fsBlockDbRoot: fsDbRootURL, + fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( - fsBlockDbRoot: fsDbRootURL, - rustBackend: rustWelding, + fsBlockDbRoot: testTempDirectory, + rustBackend: rustBackend, logger: logger ), blockDescriptor: ZcashCompactBlockDescriptor.live, @@ -189,7 +181,7 @@ class BlockScanTests: XCTestCase { var processorConfig = CompactBlockProcessor.Configuration( alias: .default, - fsBlockCacheRoot: fsDbRootURL, + fsBlockCacheRoot: testTempDirectory, dataDb: dataDbURL, spendParamsURL: spendParamsURL, outputParamsURL: outputParamsURL, @@ -202,7 +194,7 @@ class BlockScanTests: XCTestCase { let compactBlockProcessor = CompactBlockProcessor( service: service, storage: fsBlockRepository, - backend: rustWelding, + rustBackend: rustBackend, config: processorConfig, metrics: metrics, logger: logger diff --git a/Tests/NetworkTests/BlockStreamingTest.swift b/Tests/NetworkTests/BlockStreamingTest.swift index 11d1c624..3de5d21a 100644 --- a/Tests/NetworkTests/BlockStreamingTest.swift +++ b/Tests/NetworkTests/BlockStreamingTest.swift @@ -10,23 +10,24 @@ import XCTest @testable import ZcashLightClientKit class BlockStreamingTest: XCTestCase { - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) - let testFileManager = FileManager() + var rustBackend: ZcashRustBackendWelding! + var testTempDirectory: URL! override func setUpWithError() throws { try super.setUpWithError() - try self.testFileManager.createDirectory(at: self.testTempDirectory, withIntermediateDirectories: false) + testTempDirectory = Environment.uniqueTestTempDirectory + try self.testFileManager.createDirectory(at: testTempDirectory, withIntermediateDirectories: false) + rustBackend = ZcashRustBackend.makeForTests(fsBlockDbRoot: testTempDirectory, networkType: .testnet) logger = OSLogger(logLevel: .debug) } override func tearDownWithError() throws { try super.tearDownWithError() + rustBackend = nil try? FileManager.default.removeItem(at: __dataDbURL()) try? testFileManager.removeItem(at: testTempDirectory) + testTempDirectory = nil } func testStream() async throws { @@ -68,13 +69,11 @@ class BlockStreamingTest: XCTestCase { ) let service = LightWalletServiceFactory(endpoint: endpoint).make() - let realRustBackend = ZcashRustBackend.self - let storage = FSCompactBlockRepository( fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: realRustBackend, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -94,7 +93,7 @@ class BlockStreamingTest: XCTestCase { let compactBlockProcessor = CompactBlockProcessor( service: service, storage: storage, - backend: realRustBackend, + rustBackend: rustBackend, config: processorConfig, metrics: SDKMetrics(), logger: logger @@ -132,13 +131,11 @@ class BlockStreamingTest: XCTestCase { ) let service = LightWalletServiceFactory(endpoint: endpoint).make() - let realRustBackend = ZcashRustBackend.self - let storage = FSCompactBlockRepository( fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: realRustBackend, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -160,7 +157,7 @@ class BlockStreamingTest: XCTestCase { let compactBlockProcessor = CompactBlockProcessor( service: service, storage: storage, - backend: realRustBackend, + rustBackend: rustBackend, config: processorConfig, metrics: SDKMetrics(), logger: logger diff --git a/Tests/NetworkTests/CompactBlockProcessorTests.swift b/Tests/NetworkTests/CompactBlockProcessorTests.swift index 16c565ff..69cb0b0b 100644 --- a/Tests/NetworkTests/CompactBlockProcessorTests.swift +++ b/Tests/NetworkTests/CompactBlockProcessorTests.swift @@ -12,9 +12,30 @@ import XCTest @testable import ZcashLightClientKit class CompactBlockProcessorTests: XCTestCase { - lazy var processorConfig = { + var processorConfig: CompactBlockProcessor.Configuration! + var cancellables: [AnyCancellable] = [] + var processorEventHandler: CompactBlockProcessorEventHandler! = CompactBlockProcessorEventHandler() + var rustBackend: ZcashRustBackendWelding! + var processor: CompactBlockProcessor! + var syncStartedExpect: XCTestExpectation! + var updatedNotificationExpectation: XCTestExpectation! + var stopNotificationExpectation: XCTestExpectation! + var finishedNotificationExpectation: XCTestExpectation! + let network = ZcashNetworkBuilder.network(for: .testnet) + let mockLatestHeight = ZcashNetworkBuilder.network(for: .testnet).constants.saplingActivationHeight + 2000 + + let testFileManager = FileManager() + var testTempDirectory: URL! + + override func setUp() async throws { + try await super.setUp() + logger = OSLogger(logLevel: .debug) + testTempDirectory = Environment.uniqueTestTempDirectory + + try self.testFileManager.createDirectory(at: testTempDirectory, withIntermediateDirectories: false) + let pathProvider = DefaultResourceProvider(network: network) - return CompactBlockProcessor.Configuration( + processorConfig = CompactBlockProcessor.Configuration( alias: .default, fsBlockCacheRoot: testTempDirectory, dataDb: pathProvider.dataDbURL, @@ -24,28 +45,6 @@ class CompactBlockProcessorTests: XCTestCase { walletBirthdayProvider: { ZcashNetworkBuilder.network(for: .testnet).constants.saplingActivationHeight }, network: ZcashNetworkBuilder.network(for: .testnet) ) - }() - - var cancellables: [AnyCancellable] = [] - var processorEventHandler: CompactBlockProcessorEventHandler! = CompactBlockProcessorEventHandler() - var processor: CompactBlockProcessor! - var syncStartedExpect: XCTestExpectation! - var updatedNotificationExpectation: XCTestExpectation! - var stopNotificationExpectation: XCTestExpectation! - var finishedNotificationExpectation: XCTestExpectation! - let network = ZcashNetworkBuilder.network(for: .testnet) - let mockLatestHeight = ZcashNetworkBuilder.network(for: .testnet).constants.saplingActivationHeight + 2000 - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) - - let testFileManager = FileManager() - - override func setUp() async throws { - try await super.setUp() - logger = OSLogger(logLevel: .debug) - try self.testFileManager.createDirectory(at: self.testTempDirectory, withIntermediateDirectories: false) await InternalSyncProgress( alias: .default, @@ -58,7 +57,14 @@ class CompactBlockProcessorTests: XCTestCase { latestBlockHeight: mockLatestHeight, service: liveService ) - let branchID = try ZcashRustBackend.consensusBranchIdFor(height: Int32(mockLatestHeight), networkType: network.networkType) + + rustBackend = ZcashRustBackend.makeForTests( + dbData: processorConfig.dataDb, + fsBlockDbRoot: processorConfig.fsBlockCacheRoot, + networkType: network.networkType + ) + + let branchID = try rustBackend.consensusBranchIdFor(height: Int32(mockLatestHeight)) service.mockLightDInfo = LightdInfo.with({ info in info.blockHeight = UInt64(mockLatestHeight) info.branch = "asdf" @@ -70,13 +76,11 @@ class CompactBlockProcessorTests: XCTestCase { info.saplingActivationHeight = UInt64(network.constants.saplingActivationHeight) }) - let realRustBackend = ZcashRustBackend.self - let storage = FSCompactBlockRepository( fsBlockDbRoot: processorConfig.fsBlockCacheRoot, metadataStore: FSMetadataStore.live( fsBlockDbRoot: processorConfig.fsBlockCacheRoot, - rustBackend: realRustBackend, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -89,13 +93,13 @@ class CompactBlockProcessorTests: XCTestCase { processor = CompactBlockProcessor( service: service, storage: storage, - backend: realRustBackend, + rustBackend: rustBackend, config: processorConfig, metrics: SDKMetrics(), logger: logger ) - let dbInit = try await realRustBackend.initDataDb(dbData: processorConfig.dataDb, seed: nil, networkType: .testnet) + let dbInit = try await rustBackend.initDataDb(seed: nil) guard case .success = dbInit else { XCTFail("Failed to initDataDb. Expected `.success` got: \(dbInit)") @@ -125,6 +129,8 @@ class CompactBlockProcessorTests: XCTestCase { cancellables = [] processor = nil processorEventHandler = nil + rustBackend = nil + testTempDirectory = nil } func processorFailed(event: CompactBlockProcessor.Event) { diff --git a/Tests/NetworkTests/CompactBlockReorgTests.swift b/Tests/NetworkTests/CompactBlockReorgTests.swift index 6bb137f4..ce428546 100644 --- a/Tests/NetworkTests/CompactBlockReorgTests.swift +++ b/Tests/NetworkTests/CompactBlockReorgTests.swift @@ -12,9 +12,31 @@ import XCTest @testable import ZcashLightClientKit class CompactBlockReorgTests: XCTestCase { - lazy var processorConfig = { + var processorConfig: CompactBlockProcessor.Configuration! + let testFileManager = FileManager() + var cancellables: [AnyCancellable] = [] + var processorEventHandler: CompactBlockProcessorEventHandler! = CompactBlockProcessorEventHandler() + var rustBackend: ZcashRustBackendWelding! + var rustBackendMockHelper: RustBackendMockHelper! + var processor: CompactBlockProcessor! + var syncStartedExpect: XCTestExpectation! + var updatedNotificationExpectation: XCTestExpectation! + var stopNotificationExpectation: XCTestExpectation! + var finishedNotificationExpectation: XCTestExpectation! + var reorgNotificationExpectation: XCTestExpectation! + let network = ZcashNetworkBuilder.network(for: .testnet) + let mockLatestHeight = ZcashNetworkBuilder.network(for: .testnet).constants.saplingActivationHeight + 2000 + var testTempDirectory: URL! + + override func setUp() async throws { + try await super.setUp() + logger = OSLogger(logLevel: .debug) + testTempDirectory = Environment.uniqueTestTempDirectory + + try self.testFileManager.createDirectory(at: testTempDirectory, withIntermediateDirectories: false) + let pathProvider = DefaultResourceProvider(network: network) - return CompactBlockProcessor.Configuration( + processorConfig = CompactBlockProcessor.Configuration( alias: .default, fsBlockCacheRoot: testTempDirectory, dataDb: pathProvider.dataDbURL, @@ -24,30 +46,6 @@ class CompactBlockReorgTests: XCTestCase { walletBirthdayProvider: { ZcashNetworkBuilder.network(for: .testnet).constants.saplingActivationHeight }, network: ZcashNetworkBuilder.network(for: .testnet) ) - }() - - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) - - let testFileManager = FileManager() - var cancellables: [AnyCancellable] = [] - var processorEventHandler: CompactBlockProcessorEventHandler! = CompactBlockProcessorEventHandler() - - var processor: CompactBlockProcessor! - var syncStartedExpect: XCTestExpectation! - var updatedNotificationExpectation: XCTestExpectation! - var stopNotificationExpectation: XCTestExpectation! - var finishedNotificationExpectation: XCTestExpectation! - var reorgNotificationExpectation: XCTestExpectation! - let network = ZcashNetworkBuilder.network(for: .testnet) - let mockLatestHeight = ZcashNetworkBuilder.network(for: .testnet).constants.saplingActivationHeight + 2000 - - override func setUp() async throws { - try await super.setUp() - logger = OSLogger(logLevel: .debug) - try self.testFileManager.createDirectory(at: self.testTempDirectory, withIntermediateDirectories: false) await InternalSyncProgress( alias: .default, @@ -60,8 +58,14 @@ class CompactBlockReorgTests: XCTestCase { latestBlockHeight: mockLatestHeight, service: liveService ) - - let branchID = try ZcashRustBackend.consensusBranchIdFor(height: Int32(mockLatestHeight), networkType: network.networkType) + + rustBackend = ZcashRustBackend.makeForTests( + dbData: processorConfig.dataDb, + fsBlockDbRoot: processorConfig.fsBlockCacheRoot, + networkType: network.networkType + ) + + let branchID = try rustBackend.consensusBranchIdFor(height: Int32(mockLatestHeight)) service.mockLightDInfo = LightdInfo.with { info in info.blockHeight = UInt64(mockLatestHeight) info.branch = "asdf" @@ -73,13 +77,11 @@ class CompactBlockReorgTests: XCTestCase { info.saplingActivationHeight = UInt64(network.constants.saplingActivationHeight) } - let realRustBackend = ZcashRustBackend.self - let realCache = FSCompactBlockRepository( fsBlockDbRoot: processorConfig.fsBlockCacheRoot, metadataStore: FSMetadataStore.live( fsBlockDbRoot: processorConfig.fsBlockCacheRoot, - rustBackend: realRustBackend, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -89,21 +91,23 @@ class CompactBlockReorgTests: XCTestCase { try await realCache.create() - let initResult = try await realRustBackend.initDataDb(dbData: processorConfig.dataDb, seed: nil, networkType: .testnet) + let initResult = try await rustBackend.initDataDb(seed: nil) guard case .success = initResult else { XCTFail("initDataDb failed. Expected Success but got .seedRequired") return } - let mockBackend = MockRustBackend.self - mockBackend.mockValidateCombinedChainFailAfterAttempts = 3 - mockBackend.mockValidateCombinedChainKeepFailing = false - mockBackend.mockValidateCombinedChainFailureHeight = self.network.constants.saplingActivationHeight + 320 - + rustBackendMockHelper = await RustBackendMockHelper( + rustBackend: rustBackend, + mockValidateCombinedChainFailAfterAttempts: 3, + mockValidateCombinedChainKeepFailing: false, + mockValidateCombinedChainFailureError: .invalidChain(upperBound: Int32(network.constants.saplingActivationHeight + 320)) + ) + processor = CompactBlockProcessor( service: service, storage: realCache, - backend: mockBackend, + rustBackend: rustBackendMockHelper.rustBackendMock, config: processorConfig, metrics: SDKMetrics(), logger: logger @@ -134,6 +138,8 @@ class CompactBlockReorgTests: XCTestCase { cancellables = [] processorEventHandler = nil processor = nil + rustBackend = nil + rustBackendMockHelper = nil } func processorHandledReorg(event: CompactBlockProcessor.Event) { diff --git a/Tests/NetworkTests/DownloadTests.swift b/Tests/NetworkTests/DownloadTests.swift index 1062271d..d0093e0c 100644 --- a/Tests/NetworkTests/DownloadTests.swift +++ b/Tests/NetworkTests/DownloadTests.swift @@ -12,36 +12,31 @@ import SQLite @testable import ZcashLightClientKit class DownloadTests: XCTestCase { - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) - let testFileManager = FileManager() - var network = ZcashNetworkBuilder.network(for: .testnet) + var testTempDirectory: URL! override func setUpWithError() throws { try super.setUpWithError() - - try self.testFileManager.createDirectory(at: self.testTempDirectory, withIntermediateDirectories: false) + testTempDirectory = Environment.uniqueTestTempDirectory + try self.testFileManager.createDirectory(at: testTempDirectory, withIntermediateDirectories: false) } override func tearDownWithError() throws { try super.tearDownWithError() try? testFileManager.removeItem(at: testTempDirectory) + testTempDirectory = nil } func testSingleDownload() async throws { let service = LightWalletServiceFactory(endpoint: LightWalletEndpointBuilder.eccTestnet).make() - - let realRustBackend = ZcashRustBackend.self + let rustBackend = ZcashRustBackend.makeForTests(fsBlockDbRoot: testTempDirectory, networkType: network.networkType) let storage = FSCompactBlockRepository( fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: realRustBackend, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -63,7 +58,7 @@ class DownloadTests: XCTestCase { let compactBlockProcessor = CompactBlockProcessor( service: service, storage: storage, - backend: realRustBackend, + rustBackend: rustBackend, config: processorConfig, metrics: SDKMetrics(), logger: logger diff --git a/Tests/OfflineTests/BlockBatchValidationTests.swift b/Tests/OfflineTests/BlockBatchValidationTests.swift index afd90da6..36eca434 100644 --- a/Tests/OfflineTests/BlockBatchValidationTests.swift +++ b/Tests/OfflineTests/BlockBatchValidationTests.swift @@ -10,21 +10,22 @@ import XCTest @testable import ZcashLightClientKit class BlockBatchValidationTests: XCTestCase { - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) - let testFileManager = FileManager() + var rustBackend: ZcashRustBackendWelding! + var testTempDirectory: URL! override func setUpWithError() throws { try super.setUpWithError() - try self.testFileManager.createDirectory(at: self.testTempDirectory, withIntermediateDirectories: false) + testTempDirectory = Environment.uniqueTestTempDirectory + try self.testFileManager.createDirectory(at: testTempDirectory, withIntermediateDirectories: false) + rustBackend = ZcashRustBackend.makeForTests(fsBlockDbRoot: testTempDirectory, networkType: .testnet) } override func tearDownWithError() throws { try super.tearDownWithError() try? testFileManager.removeItem(at: testTempDirectory) + rustBackend = nil + testTempDirectory = nil } func testBranchIdFailure() async throws { @@ -34,13 +35,11 @@ class BlockBatchValidationTests: XCTestCase { service: LightWalletServiceFactory(endpoint: LightWalletEndpointBuilder.default).make() ) - let realRustBackend = ZcashRustBackend.self - let storage = FSCompactBlockRepository( fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: realRustBackend, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -76,14 +75,13 @@ class BlockBatchValidationTests: XCTestCase { info.consensusBranchID = "d34db33f" info.saplingActivationHeight = UInt64(network.constants.saplingActivationHeight) service.mockLightDInfo = info - - let mockRust = MockRustBackend.self - mockRust.consensusBranchID = Int32(0xd34d) + + let mockBackend = await RustBackendMockHelper(rustBackend: rustBackend, consensusBranchID: Int32(0xd34d)) let compactBlockProcessor = CompactBlockProcessor( service: service, storage: storage, - backend: mockRust, + rustBackend: mockBackend.rustBackendMock, config: config, metrics: SDKMetrics(), logger: logger @@ -109,13 +107,11 @@ class BlockBatchValidationTests: XCTestCase { service: LightWalletServiceFactory(endpoint: LightWalletEndpointBuilder.default).make() ) - let realRustBackend = ZcashRustBackend.self - let storage = FSCompactBlockRepository( fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: realRustBackend, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -151,14 +147,13 @@ class BlockBatchValidationTests: XCTestCase { info.saplingActivationHeight = UInt64(network.constants.saplingActivationHeight) service.mockLightDInfo = info - - let mockRust = MockRustBackend.self - mockRust.consensusBranchID = 0xd34db4d + + let mockBackend = await RustBackendMockHelper(rustBackend: rustBackend, consensusBranchID: 0xd34db4d) let compactBlockProcessor = CompactBlockProcessor( service: service, storage: storage, - backend: mockRust, + rustBackend: mockBackend.rustBackendMock, config: config, metrics: SDKMetrics(), logger: logger @@ -184,13 +179,11 @@ class BlockBatchValidationTests: XCTestCase { service: LightWalletServiceFactory(endpoint: LightWalletEndpointBuilder.default).make() ) - let realRustBackend = ZcashRustBackend.self - let storage = FSCompactBlockRepository( fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: realRustBackend, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -227,13 +220,12 @@ class BlockBatchValidationTests: XCTestCase { service.mockLightDInfo = info - let mockRust = MockRustBackend.self - mockRust.consensusBranchID = 0xd34db4d + let mockBackend = await RustBackendMockHelper(rustBackend: rustBackend, consensusBranchID: 0xd34db4d) let compactBlockProcessor = CompactBlockProcessor( service: service, storage: storage, - backend: mockRust, + rustBackend: mockBackend.rustBackendMock, config: config, metrics: SDKMetrics(), logger: logger @@ -259,13 +251,11 @@ class BlockBatchValidationTests: XCTestCase { service: LightWalletServiceFactory(endpoint: LightWalletEndpointBuilder.default).make() ) - let realRustBackend = ZcashRustBackend.self - let storage = FSCompactBlockRepository( fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: realRustBackend, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -303,13 +293,12 @@ class BlockBatchValidationTests: XCTestCase { service.mockLightDInfo = info - let mockRust = MockRustBackend.self - mockRust.consensusBranchID = 0xd34db4d + let mockBackend = await RustBackendMockHelper(rustBackend: rustBackend, consensusBranchID: 0xd34db4d) let compactBlockProcessor = CompactBlockProcessor( service: service, storage: storage, - backend: mockRust, + rustBackend: mockBackend.rustBackendMock, config: config, metrics: SDKMetrics(), logger: logger @@ -381,9 +370,8 @@ class BlockBatchValidationTests: XCTestCase { info.saplingActivationHeight = UInt64(network.constants.saplingActivationHeight) service.mockLightDInfo = info - - let mockRust = MockRustBackend.self - mockRust.consensusBranchID = 0xd34db4d + + let mockBackend = await RustBackendMockHelper(rustBackend: rustBackend, consensusBranchID: 0xd34db4d) var nextBatch: CompactBlockProcessor.NextState? do { @@ -392,7 +380,7 @@ class BlockBatchValidationTests: XCTestCase { downloaderService: downloaderService, transactionRepository: transactionRepository, config: config, - rustBackend: mockRust, + rustBackend: mockBackend.rustBackendMock, internalSyncProgress: InternalSyncProgress( alias: .default, storage: InternalSyncProgressMemoryStorage(), @@ -479,8 +467,7 @@ class BlockBatchValidationTests: XCTestCase { service.mockLightDInfo = info - let mockRust = MockRustBackend.self - mockRust.consensusBranchID = 0xd34db4d + let mockBackend = await RustBackendMockHelper(rustBackend: rustBackend, consensusBranchID: 0xd34db4d) var nextBatch: CompactBlockProcessor.NextState? do { @@ -489,7 +476,7 @@ class BlockBatchValidationTests: XCTestCase { downloaderService: downloaderService, transactionRepository: transactionRepository, config: config, - rustBackend: mockRust, + rustBackend: mockBackend.rustBackendMock, internalSyncProgress: InternalSyncProgress( alias: .default, storage: InternalSyncProgressMemoryStorage(), @@ -573,8 +560,7 @@ class BlockBatchValidationTests: XCTestCase { service.mockLightDInfo = info - let mockRust = MockRustBackend.self - mockRust.consensusBranchID = 0xd34db4d + let mockBackend = await RustBackendMockHelper(rustBackend: rustBackend, consensusBranchID: 0xd34db4d) var nextBatch: CompactBlockProcessor.NextState? do { @@ -583,7 +569,7 @@ class BlockBatchValidationTests: XCTestCase { downloaderService: downloaderService, transactionRepository: transactionRepository, config: config, - rustBackend: mockRust, + rustBackend: mockBackend.rustBackendMock, internalSyncProgress: internalSyncProgress ) diff --git a/Tests/OfflineTests/ClosureSynchronizerOfflineTests.swift b/Tests/OfflineTests/ClosureSynchronizerOfflineTests.swift index 46f28f65..5017ee57 100644 --- a/Tests/OfflineTests/ClosureSynchronizerOfflineTests.swift +++ b/Tests/OfflineTests/ClosureSynchronizerOfflineTests.swift @@ -14,7 +14,7 @@ import XCTest extension String: Error { } class ClosureSynchronizerOfflineTests: XCTestCase { - var data: AlternativeSynchronizerAPITestsData! + var data: TestsData! var cancellables: [AnyCancellable] = [] var synchronizerMock: SynchronizerMock! @@ -22,7 +22,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { override func setUpWithError() throws { try super.setUpWithError() - data = AlternativeSynchronizerAPITestsData() + data = TestsData(networkType: .testnet) synchronizerMock = SynchronizerMock() synchronizer = ClosureSDKSynchronizer(synchronizer: synchronizerMock) cancellables = [] @@ -114,17 +114,18 @@ class ClosureSynchronizerOfflineTests: XCTestCase { XCTAssertEqual(synchronizer.connectionState, .reconnecting) } - func testPrepareSucceed() throws { - synchronizerMock.prepareWithSeedViewingKeysWalletBirthdayClosure = { receivedSeed, receivedViewingKeys, receivedWalletBirthday in + func testPrepareSucceed() async throws { + let mockedViewingKey = await data.viewingKey + synchronizerMock.prepareWithViewingKeysWalletBirthdayClosure = { receivedSeed, receivedViewingKeys, receivedWalletBirthday in XCTAssertEqual(receivedSeed, self.data.seed) - XCTAssertEqual(receivedViewingKeys, [self.data.viewingKey]) + XCTAssertEqual(receivedViewingKeys, [mockedViewingKey]) XCTAssertEqual(receivedWalletBirthday, self.data.birthday) return .success } let expectation = XCTestExpectation() - synchronizer.prepare(with: data.seed, viewingKeys: [data.viewingKey], walletBirthday: data.birthday) { result in + synchronizer.prepare(with: data.seed, viewingKeys: [mockedViewingKey], walletBirthday: data.birthday) { result in switch result { case let .success(status): XCTAssertEqual(status, .success) @@ -137,14 +138,15 @@ class ClosureSynchronizerOfflineTests: XCTestCase { wait(for: [expectation], timeout: 0.5) } - func testPrepareThrowsError() throws { - synchronizerMock.prepareWithSeedViewingKeysWalletBirthdayClosure = { _, _, _ in + func testPrepareThrowsError() async throws { + let mockedViewingKey = await data.viewingKey + synchronizerMock.prepareWithViewingKeysWalletBirthdayClosure = { _, _, _ in throw "Some error" } let expectation = XCTestExpectation() - synchronizer.prepare(with: data.seed, viewingKeys: [data.viewingKey], walletBirthday: data.birthday) { result in + synchronizer.prepare(with: data.seed, viewingKeys: [mockedViewingKey], walletBirthday: data.birthday) { result in switch result { case .success: XCTFail("Error should be thrown.") @@ -324,14 +326,15 @@ class ClosureSynchronizerOfflineTests: XCTestCase { wait(for: [expectation], timeout: 0.5) } - func testSendToAddressSucceed() throws { + func testSendToAddressSucceed() async throws { let amount = Zatoshi(100) let recipient: Recipient = .transparent(data.transparentAddress) let memo: Memo = .text(try MemoText("Some message")) + let mockedSpendingKey = await data.spendingKey synchronizerMock .sendToAddressSpendingKeyZatoshiToAddressMemoClosure = { receivedSpendingKey, receivedZatoshi, receivedToAddress, receivedMemo in - XCTAssertEqual(receivedSpendingKey, self.data.spendingKey) + XCTAssertEqual(receivedSpendingKey, mockedSpendingKey) XCTAssertEqual(receivedZatoshi, amount) XCTAssertEqual(receivedToAddress, recipient) XCTAssertEqual(receivedMemo, memo) @@ -340,7 +343,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { let expectation = XCTestExpectation() - synchronizer.sendToAddress(spendingKey: data.spendingKey, zatoshi: amount, toAddress: recipient, memo: memo) { result in + synchronizer.sendToAddress(spendingKey: mockedSpendingKey, zatoshi: amount, toAddress: recipient, memo: memo) { result in switch result { case let .success(receivedEntity): XCTAssertEqual(receivedEntity.recipient, self.data.pendingTransactionEntity.recipient) @@ -353,10 +356,11 @@ class ClosureSynchronizerOfflineTests: XCTestCase { wait(for: [expectation], timeout: 0.5) } - func testSendToAddressThrowsError() throws { + func testSendToAddressThrowsError() async throws { let amount = Zatoshi(100) let recipient: Recipient = .transparent(data.transparentAddress) let memo: Memo = .text(try MemoText("Some message")) + let mockedSpendingKey = await data.spendingKey synchronizerMock.sendToAddressSpendingKeyZatoshiToAddressMemoClosure = { _, _, _, _ in throw "Some error" @@ -364,7 +368,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { let expectation = XCTestExpectation() - synchronizer.sendToAddress(spendingKey: data.spendingKey, zatoshi: amount, toAddress: recipient, memo: memo) { result in + synchronizer.sendToAddress(spendingKey: mockedSpendingKey, zatoshi: amount, toAddress: recipient, memo: memo) { result in switch result { case .success: XCTFail("Error should be thrown.") @@ -376,12 +380,13 @@ class ClosureSynchronizerOfflineTests: XCTestCase { wait(for: [expectation], timeout: 0.5) } - func testShieldFundsSucceed() throws { + func testShieldFundsSucceed() async throws { let memo: Memo = .text(try MemoText("Some message")) let shieldingThreshold = Zatoshi(1) + let mockedSpendingKey = await data.spendingKey synchronizerMock.shieldFundsSpendingKeyMemoShieldingThresholdClosure = { receivedSpendingKey, receivedMemo, receivedShieldingThreshold in - XCTAssertEqual(receivedSpendingKey, self.data.spendingKey) + XCTAssertEqual(receivedSpendingKey, mockedSpendingKey) XCTAssertEqual(receivedMemo, memo) XCTAssertEqual(receivedShieldingThreshold, shieldingThreshold) return self.data.pendingTransactionEntity @@ -389,7 +394,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { let expectation = XCTestExpectation() - synchronizer.shieldFunds(spendingKey: data.spendingKey, memo: memo, shieldingThreshold: shieldingThreshold) { result in + synchronizer.shieldFunds(spendingKey: mockedSpendingKey, memo: memo, shieldingThreshold: shieldingThreshold) { result in switch result { case let .success(receivedEntity): XCTAssertEqual(receivedEntity.recipient, self.data.pendingTransactionEntity.recipient) @@ -402,9 +407,10 @@ class ClosureSynchronizerOfflineTests: XCTestCase { wait(for: [expectation], timeout: 0.5) } - func testShieldFundsThrowsError() throws { + func testShieldFundsThrowsError() async throws { let memo: Memo = .text(try MemoText("Some message")) let shieldingThreshold = Zatoshi(1) + let mockedSpendingKey = await data.spendingKey synchronizerMock.shieldFundsSpendingKeyMemoShieldingThresholdClosure = { _, _, _ in throw "Some error" @@ -412,7 +418,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { let expectation = XCTestExpectation() - synchronizer.shieldFunds(spendingKey: data.spendingKey, memo: memo, shieldingThreshold: shieldingThreshold) { result in + synchronizer.shieldFunds(spendingKey: mockedSpendingKey, memo: memo, shieldingThreshold: shieldingThreshold) { result in switch result { case .success: XCTFail("Error should be thrown.") @@ -499,7 +505,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { func testGetMemosForClearedTransactionSucceed() throws { let memo: Memo = .text(try MemoText("Some message")) - synchronizerMock.getMemosForTransactionClosure = { receivedTransaction in + synchronizerMock.getMemosForClearedTransactionClosure = { receivedTransaction in XCTAssertEqual(receivedTransaction.id, self.data.clearedTransaction.id) return [memo] } @@ -521,7 +527,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { } func testGetMemosForClearedTransactionThrowsError() { - synchronizerMock.getMemosForTransactionClosure = { _ in + synchronizerMock.getMemosForClearedTransactionClosure = { _ in throw "Some error" } @@ -664,7 +670,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { } func testAllConfirmedTransactionsSucceed() throws { - synchronizerMock.allConfirmedTransactionsFromTransactionClosure = { receivedTransaction, limit in + synchronizerMock.allConfirmedTransactionsFromLimitClosure = { receivedTransaction, limit in XCTAssertEqual(receivedTransaction.id, self.data.clearedTransaction.id) XCTAssertEqual(limit, 3) return [self.data.clearedTransaction] @@ -687,7 +693,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { } func testAllConfirmedTransactionsThrowsError() throws { - synchronizerMock.allConfirmedTransactionsFromTransactionClosure = { _, _ in + synchronizerMock.allConfirmedTransactionsFromLimitClosure = { _, _ in throw "Some error" } @@ -747,7 +753,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { let skippedEntity = UnspentTransactionOutputEntityMock(address: "addr2", txid: Data(), index: 1, script: Data(), valueZat: 2, height: 3) let refreshedUTXO = (inserted: [insertedEntity], skipped: [skippedEntity]) - synchronizerMock.refreshUTXOsAddressFromHeightClosure = { receivedAddress, receivedFromHeight in + synchronizerMock.refreshUTXOsAddressFromClosure = { receivedAddress, receivedFromHeight in XCTAssertEqual(receivedAddress, self.data.transparentAddress) XCTAssertEqual(receivedFromHeight, 121000) return refreshedUTXO @@ -770,7 +776,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { } func testRefreshUTXOsThrowsError() { - synchronizerMock.refreshUTXOsAddressFromHeightClosure = { _, _ in + synchronizerMock.refreshUTXOsAddressFromClosure = { _, _ in throw "Some error" } @@ -911,7 +917,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { } func testRewindSucceed() { - synchronizerMock.rewindPolicyClosure = { receivedPolicy in + synchronizerMock.rewindClosure = { receivedPolicy in if case .quick = receivedPolicy { } else { XCTFail("Unexpected policy \(receivedPolicy)") @@ -940,7 +946,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { } func testRewindThrowsError() { - synchronizerMock.rewindPolicyClosure = { _ in + synchronizerMock.rewindClosure = { _ in return Fail(error: "some error").eraseToAnyPublisher() } diff --git a/Tests/OfflineTests/CombineSynchronizerOfflineTests.swift b/Tests/OfflineTests/CombineSynchronizerOfflineTests.swift index faf8c7fd..0ade088e 100644 --- a/Tests/OfflineTests/CombineSynchronizerOfflineTests.swift +++ b/Tests/OfflineTests/CombineSynchronizerOfflineTests.swift @@ -12,7 +12,7 @@ import XCTest @testable import ZcashLightClientKit class CombineSynchronizerOfflineTests: XCTestCase { - var data: AlternativeSynchronizerAPITestsData! + var data: TestsData! var cancellables: [AnyCancellable] = [] var synchronizerMock: SynchronizerMock! @@ -20,7 +20,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { override func setUpWithError() throws { try super.setUpWithError() - data = AlternativeSynchronizerAPITestsData() + data = TestsData(networkType: .testnet) synchronizerMock = SynchronizerMock() synchronizer = CombineSDKSynchronizer(synchronizer: synchronizerMock) cancellables = [] @@ -112,17 +112,18 @@ class CombineSynchronizerOfflineTests: XCTestCase { XCTAssertEqual(synchronizer.connectionState, .reconnecting) } - func testPrepareSucceed() throws { - synchronizerMock.prepareWithSeedViewingKeysWalletBirthdayClosure = { receivedSeed, receivedViewingKeys, receivedWalletBirthday in + func testPrepareSucceed() async throws { + let mockedViewingKey = await self.data.viewingKey + synchronizerMock.prepareWithViewingKeysWalletBirthdayClosure = { receivedSeed, receivedViewingKeys, receivedWalletBirthday in XCTAssertEqual(receivedSeed, self.data.seed) - XCTAssertEqual(receivedViewingKeys, [self.data.viewingKey]) + XCTAssertEqual(receivedViewingKeys, [mockedViewingKey]) XCTAssertEqual(receivedWalletBirthday, self.data.birthday) return .success } let expectation = XCTestExpectation() - synchronizer.prepare(with: data.seed, viewingKeys: [data.viewingKey], walletBirthday: data.birthday) + synchronizer.prepare(with: data.seed, viewingKeys: [mockedViewingKey], walletBirthday: data.birthday) .sink( receiveCompletion: { result in switch result { @@ -141,14 +142,15 @@ class CombineSynchronizerOfflineTests: XCTestCase { wait(for: [expectation], timeout: 0.5) } - func testPrepareThrowsError() throws { - synchronizerMock.prepareWithSeedViewingKeysWalletBirthdayClosure = { _, _, _ in + func testPrepareThrowsError() async throws { + let mockedViewingKey = await self.data.viewingKey + synchronizerMock.prepareWithViewingKeysWalletBirthdayClosure = { _, _, _ in throw "Some error" } let expectation = XCTestExpectation() - synchronizer.prepare(with: data.seed, viewingKeys: [data.viewingKey], walletBirthday: data.birthday) + synchronizer.prepare(with: data.seed, viewingKeys: [mockedViewingKey], walletBirthday: data.birthday) .sink( receiveCompletion: { result in switch result { @@ -329,14 +331,15 @@ class CombineSynchronizerOfflineTests: XCTestCase { wait(for: [expectation], timeout: 0.5) } - func testSendToAddressSucceed() throws { + func testSendToAddressSucceed() async throws { let amount = Zatoshi(100) let recipient: Recipient = .transparent(data.transparentAddress) let memo: Memo = .text(try MemoText("Some message")) + let mockedSpendingKey = await data.spendingKey synchronizerMock .sendToAddressSpendingKeyZatoshiToAddressMemoClosure = { receivedSpendingKey, receivedZatoshi, receivedToAddress, receivedMemo in - XCTAssertEqual(receivedSpendingKey, self.data.spendingKey) + XCTAssertEqual(receivedSpendingKey, mockedSpendingKey) XCTAssertEqual(receivedZatoshi, amount) XCTAssertEqual(receivedToAddress, recipient) XCTAssertEqual(receivedMemo, memo) @@ -345,7 +348,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { let expectation = XCTestExpectation() - synchronizer.sendToAddress(spendingKey: data.spendingKey, zatoshi: amount, toAddress: recipient, memo: memo) + synchronizer.sendToAddress(spendingKey: mockedSpendingKey, zatoshi: amount, toAddress: recipient, memo: memo) .sink( receiveCompletion: { result in switch result { @@ -364,10 +367,11 @@ class CombineSynchronizerOfflineTests: XCTestCase { wait(for: [expectation], timeout: 0.5) } - func testSendToAddressThrowsError() throws { + func testSendToAddressThrowsError() async throws { let amount = Zatoshi(100) let recipient: Recipient = .transparent(data.transparentAddress) let memo: Memo = .text(try MemoText("Some message")) + let mockedSpendingKey = await data.spendingKey synchronizerMock.sendToAddressSpendingKeyZatoshiToAddressMemoClosure = { _, _, _, _ in throw "Some error" @@ -375,7 +379,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { let expectation = XCTestExpectation() - synchronizer.sendToAddress(spendingKey: data.spendingKey, zatoshi: amount, toAddress: recipient, memo: memo) + synchronizer.sendToAddress(spendingKey: mockedSpendingKey, zatoshi: amount, toAddress: recipient, memo: memo) .sink( receiveCompletion: { result in switch result { @@ -394,12 +398,13 @@ class CombineSynchronizerOfflineTests: XCTestCase { wait(for: [expectation], timeout: 0.5) } - func testShieldFundsSucceed() throws { + func testShieldFundsSucceed() async throws { let memo: Memo = .text(try MemoText("Some message")) let shieldingThreshold = Zatoshi(1) + let mockedSpendingKey = await data.spendingKey synchronizerMock.shieldFundsSpendingKeyMemoShieldingThresholdClosure = { receivedSpendingKey, receivedMemo, receivedShieldingThreshold in - XCTAssertEqual(receivedSpendingKey, self.data.spendingKey) + XCTAssertEqual(receivedSpendingKey, mockedSpendingKey) XCTAssertEqual(receivedMemo, memo) XCTAssertEqual(receivedShieldingThreshold, shieldingThreshold) return self.data.pendingTransactionEntity @@ -407,7 +412,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { let expectation = XCTestExpectation() - synchronizer.shieldFunds(spendingKey: data.spendingKey, memo: memo, shieldingThreshold: shieldingThreshold) + synchronizer.shieldFunds(spendingKey: mockedSpendingKey, memo: memo, shieldingThreshold: shieldingThreshold) .sink( receiveCompletion: { result in switch result { @@ -426,9 +431,10 @@ class CombineSynchronizerOfflineTests: XCTestCase { wait(for: [expectation], timeout: 0.5) } - func testShieldFundsThrowsError() throws { + func testShieldFundsThrowsError() async throws { let memo: Memo = .text(try MemoText("Some message")) let shieldingThreshold = Zatoshi(1) + let mockedSpendingKey = await data.spendingKey synchronizerMock.shieldFundsSpendingKeyMemoShieldingThresholdClosure = { _, _, _ in throw "Some error" @@ -436,7 +442,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { let expectation = XCTestExpectation() - synchronizer.shieldFunds(spendingKey: data.spendingKey, memo: memo, shieldingThreshold: shieldingThreshold) + synchronizer.shieldFunds(spendingKey: mockedSpendingKey, memo: memo, shieldingThreshold: shieldingThreshold) .sink( receiveCompletion: { result in switch result { @@ -581,7 +587,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { func testGetMemosForClearedTransactionSucceed() throws { let memo: Memo = .text(try MemoText("Some message")) - synchronizerMock.getMemosForTransactionClosure = { receivedTransaction in + synchronizerMock.getMemosForClearedTransactionClosure = { receivedTransaction in XCTAssertEqual(receivedTransaction.id, self.data.clearedTransaction.id) return [memo] } @@ -608,7 +614,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { } func testGetMemosForClearedTransactionThrowsError() { - synchronizerMock.getMemosForTransactionClosure = { _ in + synchronizerMock.getMemosForClearedTransactionClosure = { _ in throw "Some error" } @@ -802,7 +808,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { } func testAllConfirmedTransactionsSucceed() throws { - synchronizerMock.allConfirmedTransactionsFromTransactionClosure = { receivedTransaction, limit in + synchronizerMock.allConfirmedTransactionsFromLimitClosure = { receivedTransaction, limit in XCTAssertEqual(receivedTransaction.id, self.data.clearedTransaction.id) XCTAssertEqual(limit, 3) return [self.data.clearedTransaction] @@ -830,7 +836,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { } func testAllConfirmedTransactionsThrowsError() throws { - synchronizerMock.allConfirmedTransactionsFromTransactionClosure = { _, _ in + synchronizerMock.allConfirmedTransactionsFromLimitClosure = { _, _ in throw "Some error" } @@ -910,7 +916,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { let skippedEntity = UnspentTransactionOutputEntityMock(address: "addr2", txid: Data(), index: 1, script: Data(), valueZat: 2, height: 3) let refreshedUTXO = (inserted: [insertedEntity], skipped: [skippedEntity]) - synchronizerMock.refreshUTXOsAddressFromHeightClosure = { receivedAddress, receivedFromHeight in + synchronizerMock.refreshUTXOsAddressFromClosure = { receivedAddress, receivedFromHeight in XCTAssertEqual(receivedAddress, self.data.transparentAddress) XCTAssertEqual(receivedFromHeight, 121000) return refreshedUTXO @@ -939,7 +945,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { } func testRefreshUTXOsThrowsError() { - synchronizerMock.refreshUTXOsAddressFromHeightClosure = { _, _ in + synchronizerMock.refreshUTXOsAddressFromClosure = { _, _ in throw "Some error" } @@ -1126,7 +1132,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { } func testRewindSucceed() { - synchronizerMock.rewindPolicyClosure = { receivedPolicy in + synchronizerMock.rewindClosure = { receivedPolicy in if case .quick = receivedPolicy { } else { XCTFail("Unexpected policy \(receivedPolicy)") @@ -1155,7 +1161,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { } func testRewindThrowsError() { - synchronizerMock.rewindPolicyClosure = { _ in + synchronizerMock.rewindClosure = { _ in return Fail(error: "some error").eraseToAnyPublisher() } diff --git a/Tests/OfflineTests/CompactBlockProcessorOfflineTests.swift b/Tests/OfflineTests/CompactBlockProcessorOfflineTests.swift index 32b3ce9e..9e6a71b8 100644 --- a/Tests/OfflineTests/CompactBlockProcessorOfflineTests.swift +++ b/Tests/OfflineTests/CompactBlockProcessorOfflineTests.swift @@ -11,24 +11,22 @@ import XCTest class CompactBlockProcessorOfflineTests: XCTestCase { let testFileManager = FileManager() - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) + var testTempDirectory: URL! override func setUpWithError() throws { try super.setUpWithError() - try self.testFileManager.createDirectory(at: self.testTempDirectory, withIntermediateDirectories: false) + testTempDirectory = Environment.uniqueTestTempDirectory + try self.testFileManager.createDirectory(at: testTempDirectory, withIntermediateDirectories: false) } override func tearDownWithError() throws { try super.tearDownWithError() - try FileManager.default.removeItem(at: self.testTempDirectory) + try FileManager.default.removeItem(at: testTempDirectory) } func testComputeProcessingRangeForSingleLoop() async throws { let network = ZcashNetworkBuilder.network(for: .testnet) - let realRustBackend = ZcashRustBackend.self + let rustBackend = ZcashRustBackend.makeForTests(fsBlockDbRoot: testTempDirectory, networkType: .testnet) let processorConfig = CompactBlockProcessor.Configuration.standard( for: network, @@ -44,7 +42,7 @@ class CompactBlockProcessorOfflineTests: XCTestCase { fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: realRustBackend, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -55,7 +53,7 @@ class CompactBlockProcessorOfflineTests: XCTestCase { let processor = CompactBlockProcessor( service: service, storage: storage, - backend: ZcashRustBackend.self, + rustBackend: rustBackend, config: processorConfig, metrics: SDKMetrics(), logger: logger diff --git a/Tests/OfflineTests/CompactBlockRepositoryTests.swift b/Tests/OfflineTests/CompactBlockRepositoryTests.swift index 28334f35..d65011a7 100644 --- a/Tests/OfflineTests/CompactBlockRepositoryTests.swift +++ b/Tests/OfflineTests/CompactBlockRepositoryTests.swift @@ -13,22 +13,22 @@ import XCTest class CompactBlockRepositoryTests: XCTestCase { let network = ZcashNetworkBuilder.network(for: .testnet) - - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) - let testFileManager = FileManager() + var rustBackend: ZcashRustBackendWelding! + var testTempDirectory: URL! override func setUpWithError() throws { try super.setUpWithError() - try self.testFileManager.createDirectory(at: self.testTempDirectory, withIntermediateDirectories: false) + testTempDirectory = Environment.uniqueTestTempDirectory + try self.testFileManager.createDirectory(at: testTempDirectory, withIntermediateDirectories: false) + rustBackend = ZcashRustBackend.makeForTests(fsBlockDbRoot: testTempDirectory, networkType: .testnet) } override func tearDownWithError() throws { try super.tearDownWithError() try? testFileManager.removeItem(at: testTempDirectory) + rustBackend = nil + testTempDirectory = nil } func testEmptyStorage() async throws { @@ -36,7 +36,7 @@ class CompactBlockRepositoryTests: XCTestCase { fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: ZcashRustBackend.self, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -55,7 +55,7 @@ class CompactBlockRepositoryTests: XCTestCase { fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: ZcashRustBackend.self, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -82,7 +82,7 @@ class CompactBlockRepositoryTests: XCTestCase { fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: ZcashRustBackend.self, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -108,7 +108,7 @@ class CompactBlockRepositoryTests: XCTestCase { fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: ZcashRustBackend.self, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, diff --git a/Tests/OfflineTests/DerivationToolTests/DerivationToolMainnetTests.swift b/Tests/OfflineTests/DerivationToolTests/DerivationToolMainnetTests.swift index a5107360..433e6151 100644 --- a/Tests/OfflineTests/DerivationToolTests/DerivationToolMainnetTests.swift +++ b/Tests/OfflineTests/DerivationToolTests/DerivationToolMainnetTests.swift @@ -16,7 +16,8 @@ class DerivationToolMainnetTests: XCTestCase { let testRecipientAddress = UnifiedAddress( validatedEncoding: """ u1l9f0l4348negsncgr9pxd9d3qaxagmqv3lnexcplmufpq7muffvfaue6ksevfvd7wrz7xrvn95rc5zjtn7ugkmgh5rnxswmcj30y0pw52pn0zjvy38rn2esfgve64rj5pcmazxgpyuj - """ + """, + networkType: .mainnet ) let expectedSpendingKey = UnifiedSpendingKey( @@ -45,82 +46,72 @@ class DerivationToolMainnetTests: XCTestCase { """) let expectedSaplingAddress = SaplingAddress(validatedEncoding: "zs1vp7kvlqr4n9gpehztr76lcn6skkss9p8keqs3nv8avkdtjrcctrvmk9a7u494kluv756jeee5k0") - - let derivationTool = DerivationTool(networkType: NetworkType.mainnet) + + let derivationTool = TestsData(networkType: .mainnet).derivationTools let expectedTransparentAddress = TransparentAddress(validatedEncoding: "t1dRJRY7GmyeykJnMH38mdQoaZtFhn1QmGz") - func testDeriveViewingKeysFromSeed() throws { + + func testDeriveViewingKeysFromSeed() async throws { let seedBytes = [UInt8](seedData) - let spendingKey = try derivationTool.deriveUnifiedSpendingKey(seed: seedBytes, accountIndex: 0) - - let viewingKey = try derivationTool.deriveUnifiedFullViewingKey(from: spendingKey) + let spendingKey = try await derivationTool.deriveUnifiedSpendingKey(seed: seedBytes, accountIndex: 0) + let viewingKey = try await derivationTool.deriveUnifiedFullViewingKey(from: spendingKey) XCTAssertEqual(expectedViewingKey, viewingKey) } - func testDeriveViewingKeyFromSpendingKeys() throws { - XCTAssertEqual( - expectedViewingKey, - try derivationTool.deriveUnifiedFullViewingKey(from: expectedSpendingKey) - ) + func testDeriveViewingKeyFromSpendingKeys() async throws { + let viewingKey = try await derivationTool.deriveUnifiedFullViewingKey(from: expectedSpendingKey) + XCTAssertEqual(expectedViewingKey, viewingKey) } - func testDeriveSpendingKeysFromSeed() throws { + func testDeriveSpendingKeysFromSeed() async throws { let seedBytes = [UInt8](seedData) - let spendingKey = try derivationTool.deriveUnifiedSpendingKey(seed: seedBytes, accountIndex: 0) + let spendingKey = try await derivationTool.deriveUnifiedSpendingKey(seed: seedBytes, accountIndex: 0) XCTAssertEqual(expectedSpendingKey, spendingKey) } - func testDeriveUnifiedSpendingKeyFromSeed() throws { + func testDeriveUnifiedSpendingKeyFromSeed() async throws { let account = 0 let seedBytes = [UInt8](seedData) - XCTAssertNoThrow(try derivationTool.deriveUnifiedSpendingKey(seed: seedBytes, accountIndex: account)) + _ = try await derivationTool.deriveUnifiedSpendingKey(seed: seedBytes, accountIndex: account) } func testGetTransparentAddressFromUA() throws { XCTAssertEqual( - try DerivationTool.transparentReceiver(from: testRecipientAddress), + try DerivationTool(networkType: .mainnet).transparentReceiver(from: testRecipientAddress), expectedTransparentAddress ) } func testIsValidViewingKey() { XCTAssertTrue( - DerivationTool.rustwelding.isValidSaplingExtendedFullViewingKey( + ZcashKeyDerivationBackend(networkType: .mainnet).isValidSaplingExtendedFullViewingKey( """ zxviews1q0dm7hkzqqqqpqplzv3f50rl4vay8uy5zg9e92f62lqg6gzu63rljety32xy5tcyenzuu3n386ws772nm6tp4sads8n37gff6nxmyz8dn9keehmapk0spc6pzx5ux\ epgu52xnwzxxnuja5tv465t9asppnj3eqncu3s7g3gzg5x8ss4ypkw08xwwyj7ky5skvnd9ldwj2u8fz2ry94s5q8p9lyp3j96yckudmp087d2jr2rnfuvjp7f56v78vpe658\ vljjddj7s645q399jd7 - """, - networkType: .mainnet + """ ) ) XCTAssertFalse( - DerivationTool.rustwelding.isValidSaplingExtendedFullViewingKey( - "zxviews1q0dm7hkzky5skvnd9ldwj2u8fz2ry94s5q8p9lyp3j96yckudmp087d2jr2rnfuvjp7f56v78vpe658vljjddj7s645q399jd7", - networkType: .mainnet + ZcashKeyDerivationBackend(networkType: .mainnet).isValidSaplingExtendedFullViewingKey( + "zxviews1q0dm7hkzky5skvnd9ldwj2u8fz2ry94s5q8p9lyp3j96yckudmp087d2jr2rnfuvjp7f56v78vpe658vljjddj7s645q399jd7" ) ) } - func testDeriveQuiteALotOfUnifiedKeysFromSeed() throws { + func testDeriveQuiteALotOfUnifiedKeysFromSeed() async throws { let numberOfAccounts: Int = 10 - let ufvks = try (0 ..< numberOfAccounts) - .map({ - try derivationTool.deriveUnifiedSpendingKey( - seed: [UInt8](seedData), - accountIndex: $0 - ) - }) - .map { - try derivationTool.deriveUnifiedFullViewingKey( - from: $0 - ) - } + var ufvks: [UnifiedFullViewingKey] = [] + for i in 0..(typecodes), @@ -70,7 +70,8 @@ final class UnifiedTypecodesTests: XCTestCase { validatedEncoding: """ u1l9f0l4348negsncgr9pxd9d3qaxagmqv3lnexcplmufpq7muffvfaue6ksevfvd7wrz7xrvn95rc5zjtn7ugkmgh5rnxswmcj30y0pw52pn0zjvy38rn2esfgve64rj5pcmazxg\ pyuj - """ + """, + networkType: .testnet ) XCTAssertEqual(try ua.availableReceiverTypecodes(), [.sapling, .p2pkh]) diff --git a/Tests/OfflineTests/WalletTests.swift b/Tests/OfflineTests/WalletTests.swift index 8625378f..c204a433 100644 --- a/Tests/OfflineTests/WalletTests.swift +++ b/Tests/OfflineTests/WalletTests.swift @@ -12,13 +12,8 @@ import XCTest @testable import ZcashLightClientKit class WalletTests: XCTestCase { - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) - let testFileManager = FileManager() - + var testTempDirectory: URL! var dbData: URL! = nil var paramDestination: URL! = nil var network = ZcashNetworkBuilder.network(for: .testnet) @@ -26,8 +21,9 @@ class WalletTests: XCTestCase { override func setUpWithError() throws { try super.setUpWithError() + testTempDirectory = Environment.uniqueTestTempDirectory dbData = try __dataDbURL() - try self.testFileManager.createDirectory(at: self.testTempDirectory, withIntermediateDirectories: false) + try self.testFileManager.createDirectory(at: testTempDirectory, withIntermediateDirectories: false) paramDestination = try __documentsDirectory().appendingPathComponent("parameters") } @@ -36,17 +32,17 @@ class WalletTests: XCTestCase { if testFileManager.fileExists(atPath: dbData.absoluteString) { try testFileManager.trashItem(at: dbData, resultingItemURL: nil) } - try? self.testFileManager.removeItem(at: self.testTempDirectory) + try? self.testFileManager.removeItem(at: testTempDirectory) } func testWalletInitialization() async throws { - let derivationTool = DerivationTool(networkType: network.networkType) - let ufvk = try derivationTool.deriveUnifiedSpendingKey(seed: seedData.bytes, accountIndex: 0) - .map({ try derivationTool.deriveUnifiedFullViewingKey(from: $0) }) + let derivationTool = TestsData(networkType: network.networkType).derivationTools + let spendingKey = try await derivationTool.deriveUnifiedSpendingKey(seed: seedData.bytes, accountIndex: 0) + let viewingKey = try await derivationTool.deriveUnifiedFullViewingKey(from: spendingKey) let wallet = Initializer( cacheDbURL: nil, - fsBlockDbRoot: self.testTempDirectory, + fsBlockDbRoot: testTempDirectory, dataDbURL: try __dataDbURL(), pendingDbURL: try TestDbBuilder.pendingTransactionsDbURL(), endpoint: LightWalletEndpointBuilder.default, @@ -58,7 +54,7 @@ class WalletTests: XCTestCase { let synchronizer = SDKSynchronizer(initializer: wallet) do { - guard case .success = try await synchronizer.prepare(with: seedData.bytes, viewingKeys: [ufvk], walletBirthday: 663194) else { + guard case .success = try await synchronizer.prepare(with: seedData.bytes, viewingKeys: [viewingKey], walletBirthday: 663194) else { XCTFail("Failed to initDataDb. Expected `.success` got: `.seedRequired`") return } diff --git a/Tests/OfflineTests/ZcashRustBackendTests.swift b/Tests/OfflineTests/ZcashRustBackendTests.swift index a2deca94..1304459c 100644 --- a/Tests/OfflineTests/ZcashRustBackendTests.swift +++ b/Tests/OfflineTests/ZcashRustBackendTests.swift @@ -12,8 +12,8 @@ import XCTest class ZcashRustBackendTests: XCTestCase { var dbData: URL! + var rustBackend: ZcashRustBackendWelding! var dataDbHandle = TestDbHandle(originalDb: TestDbBuilder.prePopulatedDataDbURL()!) - let spendingKey = """ secret-extended-key-test1qvpevftsqqqqpqy52ut2vv24a2qh7nsukew7qg9pq6djfwyc3xt5vaxuenshp2hhspp9qmqvdh0gs2ljpwxders5jkwgyhgln0drjqaguaenfhehz4esdl4k\ wlm5t9q0l6wmzcrvcf5ed6dqzvct3e2ge7f6qdvzhp02m7sp5a0qjssrwpdh7u6tq89hl3wchuq8ljq8r8rwd6xdwh3nry9at80z7amnj3s6ah4jevnvfr08gxpws523z95g6dmn4wm6l3658\ @@ -24,23 +24,27 @@ class ZcashRustBackendTests: XCTestCase { let zpend: Int = 500_000 let networkType = NetworkType.testnet - + override func setUp() { super.setUp() + dbData = try! __dataDbURL() try? dataDbHandle.setUp() + + rustBackend = ZcashRustBackend.makeForTests(dbData: dbData, fsBlockDbRoot: Environment.uniqueTestTempDirectory, networkType: .testnet) } override func tearDown() { super.tearDown() try? FileManager.default.removeItem(at: dbData!) dataDbHandle.dispose() + rustBackend = nil } func testInitWithShortSeedAndFail() async throws { let seed = "testreferencealice" - let dbInit = try await ZcashRustBackend.initDataDb(dbData: self.dbData!, seed: nil, networkType: self.networkType) + let dbInit = try await rustBackend.initDataDb(seed: nil) guard case .success = dbInit else { XCTFail("Failed to initDataDb. Expected `.success` got: \(String(describing: dbInit))") @@ -48,76 +52,64 @@ class ZcashRustBackendTests: XCTestCase { } do { - _ = try await ZcashRustBackend.createAccount(dbData: dbData!, seed: Array(seed.utf8), networkType: networkType) + _ = try await rustBackend.createAccount(seed: Array(seed.utf8)) XCTFail("createAccount should fail here.") } catch { } } func testIsValidTransparentAddressFalse() { XCTAssertFalse( - ZcashRustBackend.isValidTransparentAddress( - "ztestsapling12k9m98wmpjts2m56wc60qzhgsfvlpxcwah268xk5yz4h942sd58jy3jamqyxjwums6hw7kfa4cc", - networkType: networkType + ZcashKeyDerivationBackend(networkType: networkType).isValidTransparentAddress( + "ztestsapling12k9m98wmpjts2m56wc60qzhgsfvlpxcwah268xk5yz4h942sd58jy3jamqyxjwums6hw7kfa4cc" ) ) } func testIsValidTransparentAddressTrue() { XCTAssertTrue( - ZcashRustBackend.isValidTransparentAddress( - "tmSwpioc7reeoNrYB9SKpWkurJz3yEj3ee7", - networkType: networkType + ZcashKeyDerivationBackend(networkType: networkType).isValidTransparentAddress( + "tmSwpioc7reeoNrYB9SKpWkurJz3yEj3ee7" ) ) } func testIsValidSaplingAddressTrue() { XCTAssertTrue( - ZcashRustBackend.isValidSaplingAddress( - "ztestsapling12k9m98wmpjts2m56wc60qzhgsfvlpxcwah268xk5yz4h942sd58jy3jamqyxjwums6hw7kfa4cc", - networkType: networkType + ZcashKeyDerivationBackend(networkType: networkType).isValidSaplingAddress( + "ztestsapling12k9m98wmpjts2m56wc60qzhgsfvlpxcwah268xk5yz4h942sd58jy3jamqyxjwums6hw7kfa4cc" ) ) } func testIsValidSaplingAddressFalse() { XCTAssertFalse( - ZcashRustBackend.isValidSaplingAddress( - "tmSwpioc7reeoNrYB9SKpWkurJz3yEj3ee7", - networkType: networkType + ZcashKeyDerivationBackend(networkType: networkType).isValidSaplingAddress( + "tmSwpioc7reeoNrYB9SKpWkurJz3yEj3ee7" ) ) } func testListTransparentReceivers() async throws { let testVector = [TestVector](TestVector.testVectors![0 ... 2]) - let network = NetworkType.mainnet let tempDBs = TemporaryDbBuilder.build() let seed = testVector[0].root_seed! + rustBackend = ZcashRustBackend.makeForTests(dbData: tempDBs.dataDB, fsBlockDbRoot: Environment.uniqueTestTempDirectory, networkType: .mainnet) try? FileManager.default.removeItem(at: tempDBs.dataDB) - let initResult = try await ZcashRustBackend.initDataDb( - dbData: tempDBs.dataDB, - seed: seed, - networkType: network - ) + let initResult = try await rustBackend.initDataDb(seed: seed) XCTAssertEqual(initResult, .success) - let usk = try await ZcashRustBackend.createAccount( - dbData: tempDBs.dataDB, - seed: seed, - networkType: network - ) + let usk = try await rustBackend.createAccount(seed: seed) XCTAssertEqual(usk.account, 0) let expectedReceivers = try testVector.map { - UnifiedAddress(validatedEncoding: $0.unified_addr!) + UnifiedAddress(validatedEncoding: $0.unified_addr!, networkType: .mainnet) } .map { try $0.transparentReceiver() } let expectedUAs = testVector.map { - UnifiedAddress(validatedEncoding: $0.unified_addr!) + UnifiedAddress(validatedEncoding: $0.unified_addr!, networkType: .mainnet) } guard expectedReceivers.count >= 2 else { @@ -127,19 +119,11 @@ class ZcashRustBackendTests: XCTestCase { var uAddresses: [UnifiedAddress] = [] for i in 0...2 { uAddresses.append( - try await ZcashRustBackend.getCurrentAddress( - dbData: tempDBs.dataDB, - account: 0, - networkType: network - ) + try await rustBackend.getCurrentAddress(account: 0) ) if i < 2 { - _ = try await ZcashRustBackend.getNextAvailableAddress( - dbData: tempDBs.dataDB, - account: 0, - networkType: network - ) + _ = try await rustBackend.getNextAvailableAddress(account: 0) } } @@ -148,11 +132,7 @@ class ZcashRustBackendTests: XCTestCase { expectedUAs ) - let actualReceivers = try await ZcashRustBackend.listTransparentReceivers( - dbData: tempDBs.dataDB, - account: 0, - networkType: network - ) + let actualReceivers = try await rustBackend.listTransparentReceivers(account: 0) XCTAssertEqual( expectedReceivers.sorted(), @@ -163,7 +143,7 @@ class ZcashRustBackendTests: XCTestCase { func testGetMetadataFromAddress() throws { let recipientAddress = "zs17mg40levjezevuhdp5pqrd52zere7r7vrjgdwn5sj4xsqtm20euwahv9anxmwr3y3kmwuz8k55a" - let metadata = ZcashRustBackend.getAddressMetadata(recipientAddress) + let metadata = ZcashKeyDerivationBackend.getAddressMetadata(recipientAddress) XCTAssertEqual(metadata?.networkType, .mainnet) XCTAssertEqual(metadata?.addressType, .sapling) diff --git a/Tests/PerformanceTests/SynchronizerTests.swift b/Tests/PerformanceTests/SynchronizerTests.swift index abbdc189..c659828d 100644 --- a/Tests/PerformanceTests/SynchronizerTests.swift +++ b/Tests/PerformanceTests/SynchronizerTests.swift @@ -26,6 +26,8 @@ class SynchronizerTests: XCTestCase { var coordinator: TestCoordinator! var cancellables: [AnyCancellable] = [] var sdkSynchronizerSyncStatusHandler: SDKSynchronizerSyncStatusHandler! = SDKSynchronizerSyncStatusHandler() + var rustBackend: ZcashRustBackendWelding! + var testTempDirectory: URL! let seedPhrase = """ wish puppy smile loan doll curve hole maze file ginger hair nose key relax knife witness cannon grab despair throw review deal slush frame @@ -33,11 +35,19 @@ class SynchronizerTests: XCTestCase { var birthday: BlockHeight = 1_730_000 + override func setUp() async throws { + try await super.setUp() + testTempDirectory = Environment.uniqueTestTempDirectory + rustBackend = ZcashRustBackend.makeForTests(fsBlockDbRoot: testTempDirectory, networkType: .mainnet) + } + override func tearDown() { super.tearDown() coordinator = nil cancellables = [] sdkSynchronizerSyncStatusHandler = nil + rustBackend = nil + testTempDirectory = nil } func testHundredBlocksSync() async throws { @@ -47,11 +57,11 @@ class SynchronizerTests: XCTestCase { return } let seedBytes = [UInt8](seedData) - let spendingKey = try derivationTool.deriveUnifiedSpendingKey( + let spendingKey = try await derivationTool.deriveUnifiedSpendingKey( seed: seedBytes, accountIndex: 0 ) - let ufvk = try derivationTool.deriveUnifiedFullViewingKey(from: spendingKey) + let ufvk = try await derivationTool.deriveUnifiedFullViewingKey(from: spendingKey) let network = ZcashNetworkBuilder.network(for: .mainnet) let endpoint = LightWalletEndpoint(address: "lightwalletd.electriccoin.co", port: 9067, secure: true) diff --git a/Tests/TestUtils/CompactBlockProcessorEventHandler.swift b/Tests/TestUtils/CompactBlockProcessorEventHandler.swift index 24644fc0..3e1a5920 100644 --- a/Tests/TestUtils/CompactBlockProcessorEventHandler.swift +++ b/Tests/TestUtils/CompactBlockProcessorEventHandler.swift @@ -29,6 +29,7 @@ class CompactBlockProcessorEventHandler { func subscribe(to blockProcessor: CompactBlockProcessor, expectations: [EventIdentifier: XCTestExpectation]) async { let closure: CompactBlockProcessor.EventClosure = { event in + print("Received event: \(event.identifier)") expectations[event.identifier]?.fulfill() } diff --git a/Tests/TestUtils/Sourcery/AutoMockable.stencil b/Tests/TestUtils/Sourcery/AutoMockable.stencil new file mode 100644 index 00000000..f14be10b --- /dev/null +++ b/Tests/TestUtils/Sourcery/AutoMockable.stencil @@ -0,0 +1,165 @@ +import Combine +@testable import ZcashLightClientKit + +{% macro methodName method%}{%if method|annotated:"mockedName" %}{{ method.annotations.mockedName }}{% else %}{% call swiftifyMethodName method.selectorName %}{% endif %}{% endmacro %} +{% macro swiftifyMethodName name %}{{ name | replace:"(","_" | replace:")","" | replace:":","_" | replace:"`","" | snakeToCamelCase | lowerFirstWord }}{% endmacro %} +{% macro methodNameUpper method%}{%if method|annotated:"mockedName" %}{{ method.annotations.mockedName }}{% else %}{% call swiftifyMethodNameUpper method.selectorName %}{% endif %}{% endmacro %} +{% macro swiftifyMethodNameUpper name %}{{ name | replace:"(","_" | replace:")","" | replace:":","_" | replace:"`","" | snakeToCamelCase | upperFirstLetter }}{% endmacro %} +{% macro methodThrowableErrorDeclaration method type %} + {% if method.isStatic %}static {% endif %}var {% call methodName method %}ThrowableError: Error? + {% call methodMockPropertySetter method type "ThrowableError" "Error?" %} +{% endmacro %} +{% macro methodMockPropertySetter method type postfix propertyType %} + {% if type|annotated:"mockActor" %}{% if not method.isStatic %} + func set{% call methodNameUpper method %}{{ postfix }}(_ param: {{ propertyType }}) async { + {% call methodName method %}{{ postfix }} = param + } + {% endif %}{% endif %} +{% endmacro %} +{% macro methodThrowableErrorUsage method %} + if let error = {% if method.isStatic %}Self.{% endif %}{% call methodName method %}ThrowableError { + throw error + } +{% endmacro %} +{% macro methodReceivedParameters method %} + {%if method.parameters.count == 1 %} + {% if method.isStatic %}Self.{% endif %}{% call methodName method %}Received{% for param in method.parameters %}{{ param.name|upperFirstLetter }} = {{ param.name }}{% endfor %} + {% else %} + {% if not method.parameters.count == 0 %} + {% if method.isStatic %}Self.{% endif %}{% call methodName method %}ReceivedArguments = ({% for param in method.parameters %}{{ param.name }}: {{ param.name }}{% if not forloop.last%}, {% endif %}{% endfor %}) + {% endif %} + {% endif %} +{% endmacro %} +{% macro methodClosureName method %}{% call methodName method %}Closure{% endmacro %} +{% macro paramTypeName param, method %}{% if method.annotations[param.name] %}{{method.annotations[param.name]}}{% else %}{{ param.typeName }}{% endif %}{% endmacro %} +{% macro unwrappedParamTypeName param, method %}{% if method.annotations[param.name] %}{{method.annotations[param.name]}}{% else %}{{ param.typeName.unwrappedTypeName }}{% endif %}{% endmacro %} +{% macro closureType method type %}({% for param in method.parameters %}{% call paramTypeName param, method %}{% if not forloop.last %}, {% endif %}{% endfor %}) {% if method.isAsync %}async {% endif %}{% if method.throws %}throws {% endif %}-> {% if method.isInitializer %}Void{% else %}{{ method.returnTypeName }}{% endif %}{% endmacro %} +{% macro methodClosureDeclaration method type %} + {% if method.isStatic %}static {% endif %}var {% call methodClosureName method %}: ({% call closureType method type %})? + {% if type|annotated:"mockActor" %}{% if not method.isStatic %} + func set{% call methodNameUpper method %}Closure(_ param: ({% call closureType method type %})?) async { + {% call methodName method %}Closure = param + } + {% endif %}{% endif %} +{% endmacro %} +{% macro methodClosureCallParameters method %}{% for param in method.parameters %}{{ param.name }}{% if not forloop.last %}, {% endif %}{% endfor %}{% endmacro %} +{% macro mockMethod method type %} +{% if method|!annotated:"skipAutoMock" %} + // MARK: - {{ method.shortName }} + + {% if ((type|annotated:"mockActor") and (method.isAsync) or (method.isStatic)) or (not type|annotated:"mockActor") %} + {% if method.throws %} + {% call methodThrowableErrorDeclaration method type %} + {% endif %} + {% if not method.isInitializer %} + {% if method.isStatic %}static {% endif %}var {% call methodName method %}CallsCount = 0 + {% if method.isStatic %}static {% endif %}var {% call methodName method %}Called: Bool { + return {% if method.isStatic %}Self.{% endif %}{% call methodName method %}CallsCount > 0 + } + {% endif %} + {% if method.parameters.count == 1 %} + {% if method.isStatic %}static {% endif %}var {% call methodName method %}Received{% for param in method.parameters %}{{ param.name|upperFirstLetter }}: {% if param.isClosure %}({% endif %}{% call unwrappedParamTypeName param, method %}{% if param.isClosure %}){% endif %}?{% endfor %} + {% else %}{% if not method.parameters.count == 0 %} + {% if method.isStatic %}static {% endif %}var {% call methodName method %}ReceivedArguments: ({% for param in method.parameters %}{{ param.name }}: {% if param.typeAttributes.escaping %}{% call unwrappedParamTypeName param, method %}{% else %}{% call paramTypeName param, method %}{% endif %}{% if not forloop.last %}, {% endif %}{% endfor %})? + {% endif %}{% endif %} + {% if not method.returnTypeName.isVoid and not method.isInitializer %} + {% if method.isStatic %}static {% endif %}var {% call methodName method %}ReturnValue: {{ method.returnTypeName }}{{ '!' if not method.isOptionalReturnType }} + {% call methodMockPropertySetter method type "ReturnValue" method.returnTypeName %} + {% endif %} + {% call methodClosureDeclaration method type %} + {% endif %} + +{% if method.isInitializer %} + required {{ method.name }} { + {% call methodReceivedParameters method %} + {% call methodClosureName method %}?({% call methodClosureCallParameters method %}) + } +{% else %} + {% if (not method.isAsync) and (not method.isStatic) and (type|annotated:"mockActor") %}nonisolated {% endif %}{% if method.isStatic %}static {% endif %}func {{ method.name }}{% if method.isAsync %} async{% endif %}{% if method.throws %} throws{% endif %}{% if not method.returnTypeName.isVoid %} -> {{ method.returnTypeName }}{% endif %} { + {% if ((type|annotated:"mockActor") and ((method.isAsync) or (method.isStatic))) or (not type|annotated:"mockActor") %} + {% if method.throws %} + {% call methodThrowableErrorUsage method %} + {% endif %} + {% if method.isStatic %}Self.{% endif %}{% call methodName method %}CallsCount += 1 + {% call methodReceivedParameters method %} + {% if method.returnTypeName.isVoid %} + {% if method.throws %}try {% endif %}{% if method.isAsync %}await {% endif %}{% call methodClosureName method %}?({% call methodClosureCallParameters method %}) + {% else %} + if let closure = {% if method.isStatic %}Self.{% endif %}{% call methodClosureName method %} { + return {% if method.throws %}try {% endif %}{% if method.isAsync %}await {% endif %}closure({% call methodClosureCallParameters method %}) + } else { + return {% if method.isStatic %}Self.{% endif %}{% call methodName method %}ReturnValue + } + {% endif %} + {% else %} + {% if method.throws %}try {% endif %}{% call methodClosureName method %}!({% call methodClosureCallParameters method %}) + {% endif %} + } + +{% endif %} +{% endif %} +{% endmacro %} +{% macro mockOptionalVariable variable %} + var {% call mockedVariableName variable %}: {{ variable.typeName }} +{% endmacro %} +{% macro mockNonOptionalArrayOrDictionaryVariable variable %} + var {% call mockedVariableName variable %}: {{ variable.typeName }} { + get{% if variable.isAsync %} async{% endif %} { return {% call underlyingMockedVariableName variable %} } + } + var {% call underlyingMockedVariableName variable %}: {{ variable.typeName }} = {% if variable.isArray %}[]{% elif variable.isDictionary %}[:]{% endif %} +{% endmacro %} +{% macro mockNonOptionalVariable variable %} + var {% call mockedVariableName variable %}: {{ variable.typeName }} { + get { return {% call underlyingMockedVariableName variable %} } + } + var {% call underlyingMockedVariableName variable %}: {% if variable.typeName.isClosure %}({{ variable.typeName }})!{% else %}{{ variable.typeName }}!{% endif %} +{% endmacro %} + +{% macro underlyingMockedVariableName variable %}underlying{{ variable.name|upperFirstLetter }}{% endmacro %} +{% macro initialMockedVariableValue variable %}initial{{ variable.name|upperFirstLetter }}{% endmacro %} +{% macro mockedVariableName variable %}{{ variable.name }}{% endmacro %} + +// MARK: - AutoMockable protocols +{% for type in types.protocols where type.based.AutoMockable or type|annotated:"AutoMockable" %}{% if type.name != "AutoMockable" %} +{% if type|annotated:"moduleName" %} +/// Imported from {{ type.annotations.moduleName }} module +{% endif %} +{% if type|annotated:"targetOS" %} +#if os({{ type.annotations.targetOS }}) +{% endif %} +{% if type|annotated:"mockActor" %}actor {% else %}class {% endif %}{{ type.name }}Mock: {% if type|annotated:"baseClass" %}{{ type.annotations.baseClass }}, {% endif %}{% if type|annotated:"moduleName" %}{{ type.annotations.moduleName }}.{% endif %}{{ type.name }} { + +{% for method in type.allMethods|!definedInExtension %} + {% if (not method.isAsync) and (not method.isStatic) and (type|annotated:"mockActor") %} + nonisolated let {% call methodName method %}Closure: ({% call closureType method type %})? + {% endif %} +{% endfor %} + + init( +{% for method in type.allMethods|!definedInExtension where ((not method.isAsync) and (not method.isStatic) and (type|annotated:"mockActor")) %} + {% call methodName method %}Closure: ({% call closureType method type %})? = nil{% if not forloop.last %},{% endif %} +{% endfor %} + ) { + {% for method in type.allMethods|!definedInExtension %} + {% if (not method.isAsync) and (not method.isStatic) and (type|annotated:"mockActor") %} + self.{% call methodName method %}Closure = {% call methodName method %}Closure + {% endif %} + {% endfor %} + } +{% for variable in type.allVariables|!definedInExtension %} +{% if variable|!annotated:"skipAutoMock" %} + {% if variable.isOptional %}{% call mockOptionalVariable variable %} + {% elif variable.isArray or variable.isDictionary %}{% call mockNonOptionalArrayOrDictionaryVariable variable %} + {% else %}{% call mockNonOptionalVariable variable %} + {% endif %} +{% endif %} +{% endfor %} + +{% for method in type.allMethods|!definedInExtension %} + {% call mockMethod method type %} +{% endfor %} +} +{% if type|annotated:"targetOS" %} +#endif +{% endif %} +{% endif %}{% endfor %} diff --git a/Tests/TestUtils/Sourcery/AutoMockable.swift b/Tests/TestUtils/Sourcery/AutoMockable.swift new file mode 100644 index 00000000..c1b6e08d --- /dev/null +++ b/Tests/TestUtils/Sourcery/AutoMockable.swift @@ -0,0 +1,18 @@ +// +// AutoMockable.swift +// +// +// Created by Michal Fousek on 04.04.2023. +// + +/// This file defines types for which we need to generate mocks for usage in Player tests. +/// Each type must appear in appropriate section according to which module it comes from. + +// sourcery:begin: AutoMockable + +@testable import ZcashLightClientKit + +extension ZcashRustBackendWelding { } +extension Synchronizer { } + +// sourcery:end: diff --git a/Tests/TestUtils/Sourcery/GeneratedMocks/AutoMockable.generated.swift b/Tests/TestUtils/Sourcery/GeneratedMocks/AutoMockable.generated.swift new file mode 100644 index 00000000..60f60161 --- /dev/null +++ b/Tests/TestUtils/Sourcery/GeneratedMocks/AutoMockable.generated.swift @@ -0,0 +1,1302 @@ +// Generated using Sourcery 2.0.2 — https://github.com/krzysztofzablocki/Sourcery +// DO NOT EDIT +import Combine +@testable import ZcashLightClientKit + + + +// MARK: - AutoMockable protocols +class SynchronizerMock: Synchronizer { + + + init( + ) { + } + var alias: ZcashSynchronizerAlias { + get { return underlyingAlias } + } + var underlyingAlias: ZcashSynchronizerAlias! + var latestState: SynchronizerState { + get { return underlyingLatestState } + } + var underlyingLatestState: SynchronizerState! + var connectionState: ConnectionState { + get { return underlyingConnectionState } + } + var underlyingConnectionState: ConnectionState! + var stateStream: AnyPublisher { + get { return underlyingStateStream } + } + var underlyingStateStream: AnyPublisher! + var eventStream: AnyPublisher { + get { return underlyingEventStream } + } + var underlyingEventStream: AnyPublisher! + var metrics: SDKMetrics { + get { return underlyingMetrics } + } + var underlyingMetrics: SDKMetrics! + var pendingTransactions: [PendingTransactionEntity] { + get async { return underlyingPendingTransactions } + } + var underlyingPendingTransactions: [PendingTransactionEntity] = [] + var clearedTransactions: [ZcashTransaction.Overview] { + get async { return underlyingClearedTransactions } + } + var underlyingClearedTransactions: [ZcashTransaction.Overview] = [] + var sentTransactions: [ZcashTransaction.Sent] { + get async { return underlyingSentTransactions } + } + var underlyingSentTransactions: [ZcashTransaction.Sent] = [] + var receivedTransactions: [ZcashTransaction.Received] { + get async { return underlyingReceivedTransactions } + } + var underlyingReceivedTransactions: [ZcashTransaction.Received] = [] + + // MARK: - prepare + + var prepareWithViewingKeysWalletBirthdayThrowableError: Error? + var prepareWithViewingKeysWalletBirthdayCallsCount = 0 + var prepareWithViewingKeysWalletBirthdayCalled: Bool { + return prepareWithViewingKeysWalletBirthdayCallsCount > 0 + } + var prepareWithViewingKeysWalletBirthdayReceivedArguments: (seed: [UInt8]?, viewingKeys: [UnifiedFullViewingKey], walletBirthday: BlockHeight)? + var prepareWithViewingKeysWalletBirthdayReturnValue: Initializer.InitializationResult! + var prepareWithViewingKeysWalletBirthdayClosure: (([UInt8]?, [UnifiedFullViewingKey], BlockHeight) async throws -> Initializer.InitializationResult)? + + func prepare(with seed: [UInt8]?, viewingKeys: [UnifiedFullViewingKey], walletBirthday: BlockHeight) async throws -> Initializer.InitializationResult { + if let error = prepareWithViewingKeysWalletBirthdayThrowableError { + throw error + } + prepareWithViewingKeysWalletBirthdayCallsCount += 1 + prepareWithViewingKeysWalletBirthdayReceivedArguments = (seed: seed, viewingKeys: viewingKeys, walletBirthday: walletBirthday) + if let closure = prepareWithViewingKeysWalletBirthdayClosure { + return try await closure(seed, viewingKeys, walletBirthday) + } else { + return prepareWithViewingKeysWalletBirthdayReturnValue + } + } + + // MARK: - start + + var startRetryThrowableError: Error? + var startRetryCallsCount = 0 + var startRetryCalled: Bool { + return startRetryCallsCount > 0 + } + var startRetryReceivedRetry: Bool? + var startRetryClosure: ((Bool) async throws -> Void)? + + func start(retry: Bool) async throws { + if let error = startRetryThrowableError { + throw error + } + startRetryCallsCount += 1 + startRetryReceivedRetry = retry + try await startRetryClosure?(retry) + } + + // MARK: - stop + + var stopCallsCount = 0 + var stopCalled: Bool { + return stopCallsCount > 0 + } + var stopClosure: (() async -> Void)? + + func stop() async { + stopCallsCount += 1 + await stopClosure?() + } + + // MARK: - getSaplingAddress + + var getSaplingAddressAccountIndexThrowableError: Error? + var getSaplingAddressAccountIndexCallsCount = 0 + var getSaplingAddressAccountIndexCalled: Bool { + return getSaplingAddressAccountIndexCallsCount > 0 + } + var getSaplingAddressAccountIndexReceivedAccountIndex: Int? + var getSaplingAddressAccountIndexReturnValue: SaplingAddress! + var getSaplingAddressAccountIndexClosure: ((Int) async throws -> SaplingAddress)? + + func getSaplingAddress(accountIndex: Int) async throws -> SaplingAddress { + if let error = getSaplingAddressAccountIndexThrowableError { + throw error + } + getSaplingAddressAccountIndexCallsCount += 1 + getSaplingAddressAccountIndexReceivedAccountIndex = accountIndex + if let closure = getSaplingAddressAccountIndexClosure { + return try await closure(accountIndex) + } else { + return getSaplingAddressAccountIndexReturnValue + } + } + + // MARK: - getUnifiedAddress + + var getUnifiedAddressAccountIndexThrowableError: Error? + var getUnifiedAddressAccountIndexCallsCount = 0 + var getUnifiedAddressAccountIndexCalled: Bool { + return getUnifiedAddressAccountIndexCallsCount > 0 + } + var getUnifiedAddressAccountIndexReceivedAccountIndex: Int? + var getUnifiedAddressAccountIndexReturnValue: UnifiedAddress! + var getUnifiedAddressAccountIndexClosure: ((Int) async throws -> UnifiedAddress)? + + func getUnifiedAddress(accountIndex: Int) async throws -> UnifiedAddress { + if let error = getUnifiedAddressAccountIndexThrowableError { + throw error + } + getUnifiedAddressAccountIndexCallsCount += 1 + getUnifiedAddressAccountIndexReceivedAccountIndex = accountIndex + if let closure = getUnifiedAddressAccountIndexClosure { + return try await closure(accountIndex) + } else { + return getUnifiedAddressAccountIndexReturnValue + } + } + + // MARK: - getTransparentAddress + + var getTransparentAddressAccountIndexThrowableError: Error? + var getTransparentAddressAccountIndexCallsCount = 0 + var getTransparentAddressAccountIndexCalled: Bool { + return getTransparentAddressAccountIndexCallsCount > 0 + } + var getTransparentAddressAccountIndexReceivedAccountIndex: Int? + var getTransparentAddressAccountIndexReturnValue: TransparentAddress! + var getTransparentAddressAccountIndexClosure: ((Int) async throws -> TransparentAddress)? + + func getTransparentAddress(accountIndex: Int) async throws -> TransparentAddress { + if let error = getTransparentAddressAccountIndexThrowableError { + throw error + } + getTransparentAddressAccountIndexCallsCount += 1 + getTransparentAddressAccountIndexReceivedAccountIndex = accountIndex + if let closure = getTransparentAddressAccountIndexClosure { + return try await closure(accountIndex) + } else { + return getTransparentAddressAccountIndexReturnValue + } + } + + // MARK: - sendToAddress + + var sendToAddressSpendingKeyZatoshiToAddressMemoThrowableError: Error? + var sendToAddressSpendingKeyZatoshiToAddressMemoCallsCount = 0 + var sendToAddressSpendingKeyZatoshiToAddressMemoCalled: Bool { + return sendToAddressSpendingKeyZatoshiToAddressMemoCallsCount > 0 + } + var sendToAddressSpendingKeyZatoshiToAddressMemoReceivedArguments: (spendingKey: UnifiedSpendingKey, zatoshi: Zatoshi, toAddress: Recipient, memo: Memo?)? + var sendToAddressSpendingKeyZatoshiToAddressMemoReturnValue: PendingTransactionEntity! + var sendToAddressSpendingKeyZatoshiToAddressMemoClosure: ((UnifiedSpendingKey, Zatoshi, Recipient, Memo?) async throws -> PendingTransactionEntity)? + + func sendToAddress(spendingKey: UnifiedSpendingKey, zatoshi: Zatoshi, toAddress: Recipient, memo: Memo?) async throws -> PendingTransactionEntity { + if let error = sendToAddressSpendingKeyZatoshiToAddressMemoThrowableError { + throw error + } + sendToAddressSpendingKeyZatoshiToAddressMemoCallsCount += 1 + sendToAddressSpendingKeyZatoshiToAddressMemoReceivedArguments = (spendingKey: spendingKey, zatoshi: zatoshi, toAddress: toAddress, memo: memo) + if let closure = sendToAddressSpendingKeyZatoshiToAddressMemoClosure { + return try await closure(spendingKey, zatoshi, toAddress, memo) + } else { + return sendToAddressSpendingKeyZatoshiToAddressMemoReturnValue + } + } + + // MARK: - shieldFunds + + var shieldFundsSpendingKeyMemoShieldingThresholdThrowableError: Error? + var shieldFundsSpendingKeyMemoShieldingThresholdCallsCount = 0 + var shieldFundsSpendingKeyMemoShieldingThresholdCalled: Bool { + return shieldFundsSpendingKeyMemoShieldingThresholdCallsCount > 0 + } + var shieldFundsSpendingKeyMemoShieldingThresholdReceivedArguments: (spendingKey: UnifiedSpendingKey, memo: Memo, shieldingThreshold: Zatoshi)? + var shieldFundsSpendingKeyMemoShieldingThresholdReturnValue: PendingTransactionEntity! + var shieldFundsSpendingKeyMemoShieldingThresholdClosure: ((UnifiedSpendingKey, Memo, Zatoshi) async throws -> PendingTransactionEntity)? + + func shieldFunds(spendingKey: UnifiedSpendingKey, memo: Memo, shieldingThreshold: Zatoshi) async throws -> PendingTransactionEntity { + if let error = shieldFundsSpendingKeyMemoShieldingThresholdThrowableError { + throw error + } + shieldFundsSpendingKeyMemoShieldingThresholdCallsCount += 1 + shieldFundsSpendingKeyMemoShieldingThresholdReceivedArguments = (spendingKey: spendingKey, memo: memo, shieldingThreshold: shieldingThreshold) + if let closure = shieldFundsSpendingKeyMemoShieldingThresholdClosure { + return try await closure(spendingKey, memo, shieldingThreshold) + } else { + return shieldFundsSpendingKeyMemoShieldingThresholdReturnValue + } + } + + // MARK: - cancelSpend + + var cancelSpendTransactionCallsCount = 0 + var cancelSpendTransactionCalled: Bool { + return cancelSpendTransactionCallsCount > 0 + } + var cancelSpendTransactionReceivedTransaction: PendingTransactionEntity? + var cancelSpendTransactionReturnValue: Bool! + var cancelSpendTransactionClosure: ((PendingTransactionEntity) async -> Bool)? + + func cancelSpend(transaction: PendingTransactionEntity) async -> Bool { + cancelSpendTransactionCallsCount += 1 + cancelSpendTransactionReceivedTransaction = transaction + if let closure = cancelSpendTransactionClosure { + return await closure(transaction) + } else { + return cancelSpendTransactionReturnValue + } + } + + // MARK: - paginatedTransactions + + var paginatedTransactionsOfCallsCount = 0 + var paginatedTransactionsOfCalled: Bool { + return paginatedTransactionsOfCallsCount > 0 + } + var paginatedTransactionsOfReceivedKind: TransactionKind? + var paginatedTransactionsOfReturnValue: PaginatedTransactionRepository! + var paginatedTransactionsOfClosure: ((TransactionKind) -> PaginatedTransactionRepository)? + + func paginatedTransactions(of kind: TransactionKind) -> PaginatedTransactionRepository { + paginatedTransactionsOfCallsCount += 1 + paginatedTransactionsOfReceivedKind = kind + if let closure = paginatedTransactionsOfClosure { + return closure(kind) + } else { + return paginatedTransactionsOfReturnValue + } + } + + // MARK: - getMemos + + var getMemosForClearedTransactionThrowableError: Error? + var getMemosForClearedTransactionCallsCount = 0 + var getMemosForClearedTransactionCalled: Bool { + return getMemosForClearedTransactionCallsCount > 0 + } + var getMemosForClearedTransactionReceivedTransaction: ZcashTransaction.Overview? + var getMemosForClearedTransactionReturnValue: [Memo]! + var getMemosForClearedTransactionClosure: ((ZcashTransaction.Overview) async throws -> [Memo])? + + func getMemos(for transaction: ZcashTransaction.Overview) async throws -> [Memo] { + if let error = getMemosForClearedTransactionThrowableError { + throw error + } + getMemosForClearedTransactionCallsCount += 1 + getMemosForClearedTransactionReceivedTransaction = transaction + if let closure = getMemosForClearedTransactionClosure { + return try await closure(transaction) + } else { + return getMemosForClearedTransactionReturnValue + } + } + + // MARK: - getMemos + + var getMemosForReceivedTransactionThrowableError: Error? + var getMemosForReceivedTransactionCallsCount = 0 + var getMemosForReceivedTransactionCalled: Bool { + return getMemosForReceivedTransactionCallsCount > 0 + } + var getMemosForReceivedTransactionReceivedReceivedTransaction: ZcashTransaction.Received? + var getMemosForReceivedTransactionReturnValue: [Memo]! + var getMemosForReceivedTransactionClosure: ((ZcashTransaction.Received) async throws -> [Memo])? + + func getMemos(for receivedTransaction: ZcashTransaction.Received) async throws -> [Memo] { + if let error = getMemosForReceivedTransactionThrowableError { + throw error + } + getMemosForReceivedTransactionCallsCount += 1 + getMemosForReceivedTransactionReceivedReceivedTransaction = receivedTransaction + if let closure = getMemosForReceivedTransactionClosure { + return try await closure(receivedTransaction) + } else { + return getMemosForReceivedTransactionReturnValue + } + } + + // MARK: - getMemos + + var getMemosForSentTransactionThrowableError: Error? + var getMemosForSentTransactionCallsCount = 0 + var getMemosForSentTransactionCalled: Bool { + return getMemosForSentTransactionCallsCount > 0 + } + var getMemosForSentTransactionReceivedSentTransaction: ZcashTransaction.Sent? + var getMemosForSentTransactionReturnValue: [Memo]! + var getMemosForSentTransactionClosure: ((ZcashTransaction.Sent) async throws -> [Memo])? + + func getMemos(for sentTransaction: ZcashTransaction.Sent) async throws -> [Memo] { + if let error = getMemosForSentTransactionThrowableError { + throw error + } + getMemosForSentTransactionCallsCount += 1 + getMemosForSentTransactionReceivedSentTransaction = sentTransaction + if let closure = getMemosForSentTransactionClosure { + return try await closure(sentTransaction) + } else { + return getMemosForSentTransactionReturnValue + } + } + + // MARK: - getRecipients + + var getRecipientsForClearedTransactionCallsCount = 0 + var getRecipientsForClearedTransactionCalled: Bool { + return getRecipientsForClearedTransactionCallsCount > 0 + } + var getRecipientsForClearedTransactionReceivedTransaction: ZcashTransaction.Overview? + var getRecipientsForClearedTransactionReturnValue: [TransactionRecipient]! + var getRecipientsForClearedTransactionClosure: ((ZcashTransaction.Overview) async -> [TransactionRecipient])? + + func getRecipients(for transaction: ZcashTransaction.Overview) async -> [TransactionRecipient] { + getRecipientsForClearedTransactionCallsCount += 1 + getRecipientsForClearedTransactionReceivedTransaction = transaction + if let closure = getRecipientsForClearedTransactionClosure { + return await closure(transaction) + } else { + return getRecipientsForClearedTransactionReturnValue + } + } + + // MARK: - getRecipients + + var getRecipientsForSentTransactionCallsCount = 0 + var getRecipientsForSentTransactionCalled: Bool { + return getRecipientsForSentTransactionCallsCount > 0 + } + var getRecipientsForSentTransactionReceivedTransaction: ZcashTransaction.Sent? + var getRecipientsForSentTransactionReturnValue: [TransactionRecipient]! + var getRecipientsForSentTransactionClosure: ((ZcashTransaction.Sent) async -> [TransactionRecipient])? + + func getRecipients(for transaction: ZcashTransaction.Sent) async -> [TransactionRecipient] { + getRecipientsForSentTransactionCallsCount += 1 + getRecipientsForSentTransactionReceivedTransaction = transaction + if let closure = getRecipientsForSentTransactionClosure { + return await closure(transaction) + } else { + return getRecipientsForSentTransactionReturnValue + } + } + + // MARK: - allConfirmedTransactions + + var allConfirmedTransactionsFromLimitThrowableError: Error? + var allConfirmedTransactionsFromLimitCallsCount = 0 + var allConfirmedTransactionsFromLimitCalled: Bool { + return allConfirmedTransactionsFromLimitCallsCount > 0 + } + var allConfirmedTransactionsFromLimitReceivedArguments: (transaction: ZcashTransaction.Overview, limit: Int)? + var allConfirmedTransactionsFromLimitReturnValue: [ZcashTransaction.Overview]! + var allConfirmedTransactionsFromLimitClosure: ((ZcashTransaction.Overview, Int) async throws -> [ZcashTransaction.Overview])? + + func allConfirmedTransactions(from transaction: ZcashTransaction.Overview, limit: Int) async throws -> [ZcashTransaction.Overview] { + if let error = allConfirmedTransactionsFromLimitThrowableError { + throw error + } + allConfirmedTransactionsFromLimitCallsCount += 1 + allConfirmedTransactionsFromLimitReceivedArguments = (transaction: transaction, limit: limit) + if let closure = allConfirmedTransactionsFromLimitClosure { + return try await closure(transaction, limit) + } else { + return allConfirmedTransactionsFromLimitReturnValue + } + } + + // MARK: - latestHeight + + var latestHeightThrowableError: Error? + var latestHeightCallsCount = 0 + var latestHeightCalled: Bool { + return latestHeightCallsCount > 0 + } + var latestHeightReturnValue: BlockHeight! + var latestHeightClosure: (() async throws -> BlockHeight)? + + func latestHeight() async throws -> BlockHeight { + if let error = latestHeightThrowableError { + throw error + } + latestHeightCallsCount += 1 + if let closure = latestHeightClosure { + return try await closure() + } else { + return latestHeightReturnValue + } + } + + // MARK: - refreshUTXOs + + var refreshUTXOsAddressFromThrowableError: Error? + var refreshUTXOsAddressFromCallsCount = 0 + var refreshUTXOsAddressFromCalled: Bool { + return refreshUTXOsAddressFromCallsCount > 0 + } + var refreshUTXOsAddressFromReceivedArguments: (address: TransparentAddress, height: BlockHeight)? + var refreshUTXOsAddressFromReturnValue: RefreshedUTXOs! + var refreshUTXOsAddressFromClosure: ((TransparentAddress, BlockHeight) async throws -> RefreshedUTXOs)? + + func refreshUTXOs(address: TransparentAddress, from height: BlockHeight) async throws -> RefreshedUTXOs { + if let error = refreshUTXOsAddressFromThrowableError { + throw error + } + refreshUTXOsAddressFromCallsCount += 1 + refreshUTXOsAddressFromReceivedArguments = (address: address, height: height) + if let closure = refreshUTXOsAddressFromClosure { + return try await closure(address, height) + } else { + return refreshUTXOsAddressFromReturnValue + } + } + + // MARK: - getTransparentBalance + + var getTransparentBalanceAccountIndexThrowableError: Error? + var getTransparentBalanceAccountIndexCallsCount = 0 + var getTransparentBalanceAccountIndexCalled: Bool { + return getTransparentBalanceAccountIndexCallsCount > 0 + } + var getTransparentBalanceAccountIndexReceivedAccountIndex: Int? + var getTransparentBalanceAccountIndexReturnValue: WalletBalance! + var getTransparentBalanceAccountIndexClosure: ((Int) async throws -> WalletBalance)? + + func getTransparentBalance(accountIndex: Int) async throws -> WalletBalance { + if let error = getTransparentBalanceAccountIndexThrowableError { + throw error + } + getTransparentBalanceAccountIndexCallsCount += 1 + getTransparentBalanceAccountIndexReceivedAccountIndex = accountIndex + if let closure = getTransparentBalanceAccountIndexClosure { + return try await closure(accountIndex) + } else { + return getTransparentBalanceAccountIndexReturnValue + } + } + + // MARK: - getShieldedBalance + + var getShieldedBalanceAccountIndexThrowableError: Error? + var getShieldedBalanceAccountIndexCallsCount = 0 + var getShieldedBalanceAccountIndexCalled: Bool { + return getShieldedBalanceAccountIndexCallsCount > 0 + } + var getShieldedBalanceAccountIndexReceivedAccountIndex: Int? + var getShieldedBalanceAccountIndexReturnValue: Zatoshi! + var getShieldedBalanceAccountIndexClosure: ((Int) async throws -> Zatoshi)? + + func getShieldedBalance(accountIndex: Int) async throws -> Zatoshi { + if let error = getShieldedBalanceAccountIndexThrowableError { + throw error + } + getShieldedBalanceAccountIndexCallsCount += 1 + getShieldedBalanceAccountIndexReceivedAccountIndex = accountIndex + if let closure = getShieldedBalanceAccountIndexClosure { + return try await closure(accountIndex) + } else { + return getShieldedBalanceAccountIndexReturnValue + } + } + + // MARK: - getShieldedVerifiedBalance + + var getShieldedVerifiedBalanceAccountIndexThrowableError: Error? + var getShieldedVerifiedBalanceAccountIndexCallsCount = 0 + var getShieldedVerifiedBalanceAccountIndexCalled: Bool { + return getShieldedVerifiedBalanceAccountIndexCallsCount > 0 + } + var getShieldedVerifiedBalanceAccountIndexReceivedAccountIndex: Int? + var getShieldedVerifiedBalanceAccountIndexReturnValue: Zatoshi! + var getShieldedVerifiedBalanceAccountIndexClosure: ((Int) async throws -> Zatoshi)? + + func getShieldedVerifiedBalance(accountIndex: Int) async throws -> Zatoshi { + if let error = getShieldedVerifiedBalanceAccountIndexThrowableError { + throw error + } + getShieldedVerifiedBalanceAccountIndexCallsCount += 1 + getShieldedVerifiedBalanceAccountIndexReceivedAccountIndex = accountIndex + if let closure = getShieldedVerifiedBalanceAccountIndexClosure { + return try await closure(accountIndex) + } else { + return getShieldedVerifiedBalanceAccountIndexReturnValue + } + } + + // MARK: - rewind + + var rewindCallsCount = 0 + var rewindCalled: Bool { + return rewindCallsCount > 0 + } + var rewindReceivedPolicy: RewindPolicy? + var rewindReturnValue: AnyPublisher! + var rewindClosure: ((RewindPolicy) -> AnyPublisher)? + + func rewind(_ policy: RewindPolicy) -> AnyPublisher { + rewindCallsCount += 1 + rewindReceivedPolicy = policy + if let closure = rewindClosure { + return closure(policy) + } else { + return rewindReturnValue + } + } + + // MARK: - wipe + + var wipeCallsCount = 0 + var wipeCalled: Bool { + return wipeCallsCount > 0 + } + var wipeReturnValue: AnyPublisher! + var wipeClosure: (() -> AnyPublisher)? + + func wipe() -> AnyPublisher { + wipeCallsCount += 1 + if let closure = wipeClosure { + return closure() + } else { + return wipeReturnValue + } + } + +} +actor ZcashRustBackendWeldingMock: ZcashRustBackendWelding { + + nonisolated let consensusBranchIdForHeightClosure: ((Int32) throws -> Int32)? + + init( + consensusBranchIdForHeightClosure: ((Int32) throws -> Int32)? = nil + ) { + self.consensusBranchIdForHeightClosure = consensusBranchIdForHeightClosure + } + + // MARK: - createAccount + + var createAccountSeedThrowableError: Error? + func setCreateAccountSeedThrowableError(_ param: Error?) async { + createAccountSeedThrowableError = param + } + var createAccountSeedCallsCount = 0 + var createAccountSeedCalled: Bool { + return createAccountSeedCallsCount > 0 + } + var createAccountSeedReceivedSeed: [UInt8]? + var createAccountSeedReturnValue: UnifiedSpendingKey! + func setCreateAccountSeedReturnValue(_ param: UnifiedSpendingKey) async { + createAccountSeedReturnValue = param + } + var createAccountSeedClosure: (([UInt8]) async throws -> UnifiedSpendingKey)? + func setCreateAccountSeedClosure(_ param: (([UInt8]) async throws -> UnifiedSpendingKey)?) async { + createAccountSeedClosure = param + } + + func createAccount(seed: [UInt8]) async throws -> UnifiedSpendingKey { + if let error = createAccountSeedThrowableError { + throw error + } + createAccountSeedCallsCount += 1 + createAccountSeedReceivedSeed = seed + if let closure = createAccountSeedClosure { + return try await closure(seed) + } else { + return createAccountSeedReturnValue + } + } + + // MARK: - createToAddress + + var createToAddressUskToValueMemoThrowableError: Error? + func setCreateToAddressUskToValueMemoThrowableError(_ param: Error?) async { + createToAddressUskToValueMemoThrowableError = param + } + var createToAddressUskToValueMemoCallsCount = 0 + var createToAddressUskToValueMemoCalled: Bool { + return createToAddressUskToValueMemoCallsCount > 0 + } + var createToAddressUskToValueMemoReceivedArguments: (usk: UnifiedSpendingKey, address: String, value: Int64, memo: MemoBytes?)? + var createToAddressUskToValueMemoReturnValue: Int64! + func setCreateToAddressUskToValueMemoReturnValue(_ param: Int64) async { + createToAddressUskToValueMemoReturnValue = param + } + var createToAddressUskToValueMemoClosure: ((UnifiedSpendingKey, String, Int64, MemoBytes?) async throws -> Int64)? + func setCreateToAddressUskToValueMemoClosure(_ param: ((UnifiedSpendingKey, String, Int64, MemoBytes?) async throws -> Int64)?) async { + createToAddressUskToValueMemoClosure = param + } + + func createToAddress(usk: UnifiedSpendingKey, to address: String, value: Int64, memo: MemoBytes?) async throws -> Int64 { + if let error = createToAddressUskToValueMemoThrowableError { + throw error + } + createToAddressUskToValueMemoCallsCount += 1 + createToAddressUskToValueMemoReceivedArguments = (usk: usk, address: address, value: value, memo: memo) + if let closure = createToAddressUskToValueMemoClosure { + return try await closure(usk, address, value, memo) + } else { + return createToAddressUskToValueMemoReturnValue + } + } + + // MARK: - decryptAndStoreTransaction + + var decryptAndStoreTransactionTxBytesMinedHeightThrowableError: Error? + func setDecryptAndStoreTransactionTxBytesMinedHeightThrowableError(_ param: Error?) async { + decryptAndStoreTransactionTxBytesMinedHeightThrowableError = param + } + var decryptAndStoreTransactionTxBytesMinedHeightCallsCount = 0 + var decryptAndStoreTransactionTxBytesMinedHeightCalled: Bool { + return decryptAndStoreTransactionTxBytesMinedHeightCallsCount > 0 + } + var decryptAndStoreTransactionTxBytesMinedHeightReceivedArguments: (txBytes: [UInt8], minedHeight: Int32)? + var decryptAndStoreTransactionTxBytesMinedHeightClosure: (([UInt8], Int32) async throws -> Void)? + func setDecryptAndStoreTransactionTxBytesMinedHeightClosure(_ param: (([UInt8], Int32) async throws -> Void)?) async { + decryptAndStoreTransactionTxBytesMinedHeightClosure = param + } + + func decryptAndStoreTransaction(txBytes: [UInt8], minedHeight: Int32) async throws { + if let error = decryptAndStoreTransactionTxBytesMinedHeightThrowableError { + throw error + } + decryptAndStoreTransactionTxBytesMinedHeightCallsCount += 1 + decryptAndStoreTransactionTxBytesMinedHeightReceivedArguments = (txBytes: txBytes, minedHeight: minedHeight) + try await decryptAndStoreTransactionTxBytesMinedHeightClosure?(txBytes, minedHeight) + } + + // MARK: - getBalance + + var getBalanceAccountThrowableError: Error? + func setGetBalanceAccountThrowableError(_ param: Error?) async { + getBalanceAccountThrowableError = param + } + var getBalanceAccountCallsCount = 0 + var getBalanceAccountCalled: Bool { + return getBalanceAccountCallsCount > 0 + } + var getBalanceAccountReceivedAccount: Int32? + var getBalanceAccountReturnValue: Int64! + func setGetBalanceAccountReturnValue(_ param: Int64) async { + getBalanceAccountReturnValue = param + } + var getBalanceAccountClosure: ((Int32) async throws -> Int64)? + func setGetBalanceAccountClosure(_ param: ((Int32) async throws -> Int64)?) async { + getBalanceAccountClosure = param + } + + func getBalance(account: Int32) async throws -> Int64 { + if let error = getBalanceAccountThrowableError { + throw error + } + getBalanceAccountCallsCount += 1 + getBalanceAccountReceivedAccount = account + if let closure = getBalanceAccountClosure { + return try await closure(account) + } else { + return getBalanceAccountReturnValue + } + } + + // MARK: - getCurrentAddress + + var getCurrentAddressAccountThrowableError: Error? + func setGetCurrentAddressAccountThrowableError(_ param: Error?) async { + getCurrentAddressAccountThrowableError = param + } + var getCurrentAddressAccountCallsCount = 0 + var getCurrentAddressAccountCalled: Bool { + return getCurrentAddressAccountCallsCount > 0 + } + var getCurrentAddressAccountReceivedAccount: Int32? + var getCurrentAddressAccountReturnValue: UnifiedAddress! + func setGetCurrentAddressAccountReturnValue(_ param: UnifiedAddress) async { + getCurrentAddressAccountReturnValue = param + } + var getCurrentAddressAccountClosure: ((Int32) async throws -> UnifiedAddress)? + func setGetCurrentAddressAccountClosure(_ param: ((Int32) async throws -> UnifiedAddress)?) async { + getCurrentAddressAccountClosure = param + } + + func getCurrentAddress(account: Int32) async throws -> UnifiedAddress { + if let error = getCurrentAddressAccountThrowableError { + throw error + } + getCurrentAddressAccountCallsCount += 1 + getCurrentAddressAccountReceivedAccount = account + if let closure = getCurrentAddressAccountClosure { + return try await closure(account) + } else { + return getCurrentAddressAccountReturnValue + } + } + + // MARK: - getNearestRewindHeight + + var getNearestRewindHeightHeightThrowableError: Error? + func setGetNearestRewindHeightHeightThrowableError(_ param: Error?) async { + getNearestRewindHeightHeightThrowableError = param + } + var getNearestRewindHeightHeightCallsCount = 0 + var getNearestRewindHeightHeightCalled: Bool { + return getNearestRewindHeightHeightCallsCount > 0 + } + var getNearestRewindHeightHeightReceivedHeight: Int32? + var getNearestRewindHeightHeightReturnValue: Int32! + func setGetNearestRewindHeightHeightReturnValue(_ param: Int32) async { + getNearestRewindHeightHeightReturnValue = param + } + var getNearestRewindHeightHeightClosure: ((Int32) async throws -> Int32)? + func setGetNearestRewindHeightHeightClosure(_ param: ((Int32) async throws -> Int32)?) async { + getNearestRewindHeightHeightClosure = param + } + + func getNearestRewindHeight(height: Int32) async throws -> Int32 { + if let error = getNearestRewindHeightHeightThrowableError { + throw error + } + getNearestRewindHeightHeightCallsCount += 1 + getNearestRewindHeightHeightReceivedHeight = height + if let closure = getNearestRewindHeightHeightClosure { + return try await closure(height) + } else { + return getNearestRewindHeightHeightReturnValue + } + } + + // MARK: - getNextAvailableAddress + + var getNextAvailableAddressAccountThrowableError: Error? + func setGetNextAvailableAddressAccountThrowableError(_ param: Error?) async { + getNextAvailableAddressAccountThrowableError = param + } + var getNextAvailableAddressAccountCallsCount = 0 + var getNextAvailableAddressAccountCalled: Bool { + return getNextAvailableAddressAccountCallsCount > 0 + } + var getNextAvailableAddressAccountReceivedAccount: Int32? + var getNextAvailableAddressAccountReturnValue: UnifiedAddress! + func setGetNextAvailableAddressAccountReturnValue(_ param: UnifiedAddress) async { + getNextAvailableAddressAccountReturnValue = param + } + var getNextAvailableAddressAccountClosure: ((Int32) async throws -> UnifiedAddress)? + func setGetNextAvailableAddressAccountClosure(_ param: ((Int32) async throws -> UnifiedAddress)?) async { + getNextAvailableAddressAccountClosure = param + } + + func getNextAvailableAddress(account: Int32) async throws -> UnifiedAddress { + if let error = getNextAvailableAddressAccountThrowableError { + throw error + } + getNextAvailableAddressAccountCallsCount += 1 + getNextAvailableAddressAccountReceivedAccount = account + if let closure = getNextAvailableAddressAccountClosure { + return try await closure(account) + } else { + return getNextAvailableAddressAccountReturnValue + } + } + + // MARK: - getReceivedMemo + + var getReceivedMemoIdNoteCallsCount = 0 + var getReceivedMemoIdNoteCalled: Bool { + return getReceivedMemoIdNoteCallsCount > 0 + } + var getReceivedMemoIdNoteReceivedIdNote: Int64? + var getReceivedMemoIdNoteReturnValue: Memo? + func setGetReceivedMemoIdNoteReturnValue(_ param: Memo?) async { + getReceivedMemoIdNoteReturnValue = param + } + var getReceivedMemoIdNoteClosure: ((Int64) async -> Memo?)? + func setGetReceivedMemoIdNoteClosure(_ param: ((Int64) async -> Memo?)?) async { + getReceivedMemoIdNoteClosure = param + } + + func getReceivedMemo(idNote: Int64) async -> Memo? { + getReceivedMemoIdNoteCallsCount += 1 + getReceivedMemoIdNoteReceivedIdNote = idNote + if let closure = getReceivedMemoIdNoteClosure { + return await closure(idNote) + } else { + return getReceivedMemoIdNoteReturnValue + } + } + + // MARK: - getSentMemo + + var getSentMemoIdNoteCallsCount = 0 + var getSentMemoIdNoteCalled: Bool { + return getSentMemoIdNoteCallsCount > 0 + } + var getSentMemoIdNoteReceivedIdNote: Int64? + var getSentMemoIdNoteReturnValue: Memo? + func setGetSentMemoIdNoteReturnValue(_ param: Memo?) async { + getSentMemoIdNoteReturnValue = param + } + var getSentMemoIdNoteClosure: ((Int64) async -> Memo?)? + func setGetSentMemoIdNoteClosure(_ param: ((Int64) async -> Memo?)?) async { + getSentMemoIdNoteClosure = param + } + + func getSentMemo(idNote: Int64) async -> Memo? { + getSentMemoIdNoteCallsCount += 1 + getSentMemoIdNoteReceivedIdNote = idNote + if let closure = getSentMemoIdNoteClosure { + return await closure(idNote) + } else { + return getSentMemoIdNoteReturnValue + } + } + + // MARK: - getTransparentBalance + + var getTransparentBalanceAccountThrowableError: Error? + func setGetTransparentBalanceAccountThrowableError(_ param: Error?) async { + getTransparentBalanceAccountThrowableError = param + } + var getTransparentBalanceAccountCallsCount = 0 + var getTransparentBalanceAccountCalled: Bool { + return getTransparentBalanceAccountCallsCount > 0 + } + var getTransparentBalanceAccountReceivedAccount: Int32? + var getTransparentBalanceAccountReturnValue: Int64! + func setGetTransparentBalanceAccountReturnValue(_ param: Int64) async { + getTransparentBalanceAccountReturnValue = param + } + var getTransparentBalanceAccountClosure: ((Int32) async throws -> Int64)? + func setGetTransparentBalanceAccountClosure(_ param: ((Int32) async throws -> Int64)?) async { + getTransparentBalanceAccountClosure = param + } + + func getTransparentBalance(account: Int32) async throws -> Int64 { + if let error = getTransparentBalanceAccountThrowableError { + throw error + } + getTransparentBalanceAccountCallsCount += 1 + getTransparentBalanceAccountReceivedAccount = account + if let closure = getTransparentBalanceAccountClosure { + return try await closure(account) + } else { + return getTransparentBalanceAccountReturnValue + } + } + + // MARK: - initAccountsTable + + var initAccountsTableUfvksThrowableError: Error? + func setInitAccountsTableUfvksThrowableError(_ param: Error?) async { + initAccountsTableUfvksThrowableError = param + } + var initAccountsTableUfvksCallsCount = 0 + var initAccountsTableUfvksCalled: Bool { + return initAccountsTableUfvksCallsCount > 0 + } + var initAccountsTableUfvksReceivedUfvks: [UnifiedFullViewingKey]? + var initAccountsTableUfvksClosure: (([UnifiedFullViewingKey]) async throws -> Void)? + func setInitAccountsTableUfvksClosure(_ param: (([UnifiedFullViewingKey]) async throws -> Void)?) async { + initAccountsTableUfvksClosure = param + } + + func initAccountsTable(ufvks: [UnifiedFullViewingKey]) async throws { + if let error = initAccountsTableUfvksThrowableError { + throw error + } + initAccountsTableUfvksCallsCount += 1 + initAccountsTableUfvksReceivedUfvks = ufvks + try await initAccountsTableUfvksClosure?(ufvks) + } + + // MARK: - initDataDb + + var initDataDbSeedThrowableError: Error? + func setInitDataDbSeedThrowableError(_ param: Error?) async { + initDataDbSeedThrowableError = param + } + var initDataDbSeedCallsCount = 0 + var initDataDbSeedCalled: Bool { + return initDataDbSeedCallsCount > 0 + } + var initDataDbSeedReceivedSeed: [UInt8]? + var initDataDbSeedReturnValue: DbInitResult! + func setInitDataDbSeedReturnValue(_ param: DbInitResult) async { + initDataDbSeedReturnValue = param + } + var initDataDbSeedClosure: (([UInt8]?) async throws -> DbInitResult)? + func setInitDataDbSeedClosure(_ param: (([UInt8]?) async throws -> DbInitResult)?) async { + initDataDbSeedClosure = param + } + + func initDataDb(seed: [UInt8]?) async throws -> DbInitResult { + if let error = initDataDbSeedThrowableError { + throw error + } + initDataDbSeedCallsCount += 1 + initDataDbSeedReceivedSeed = seed + if let closure = initDataDbSeedClosure { + return try await closure(seed) + } else { + return initDataDbSeedReturnValue + } + } + + // MARK: - initBlocksTable + + var initBlocksTableHeightHashTimeSaplingTreeThrowableError: Error? + func setInitBlocksTableHeightHashTimeSaplingTreeThrowableError(_ param: Error?) async { + initBlocksTableHeightHashTimeSaplingTreeThrowableError = param + } + var initBlocksTableHeightHashTimeSaplingTreeCallsCount = 0 + var initBlocksTableHeightHashTimeSaplingTreeCalled: Bool { + return initBlocksTableHeightHashTimeSaplingTreeCallsCount > 0 + } + var initBlocksTableHeightHashTimeSaplingTreeReceivedArguments: (height: Int32, hash: String, time: UInt32, saplingTree: String)? + var initBlocksTableHeightHashTimeSaplingTreeClosure: ((Int32, String, UInt32, String) async throws -> Void)? + func setInitBlocksTableHeightHashTimeSaplingTreeClosure(_ param: ((Int32, String, UInt32, String) async throws -> Void)?) async { + initBlocksTableHeightHashTimeSaplingTreeClosure = param + } + + func initBlocksTable(height: Int32, hash: String, time: UInt32, saplingTree: String) async throws { + if let error = initBlocksTableHeightHashTimeSaplingTreeThrowableError { + throw error + } + initBlocksTableHeightHashTimeSaplingTreeCallsCount += 1 + initBlocksTableHeightHashTimeSaplingTreeReceivedArguments = (height: height, hash: hash, time: time, saplingTree: saplingTree) + try await initBlocksTableHeightHashTimeSaplingTreeClosure?(height, hash, time, saplingTree) + } + + // MARK: - listTransparentReceivers + + var listTransparentReceiversAccountThrowableError: Error? + func setListTransparentReceiversAccountThrowableError(_ param: Error?) async { + listTransparentReceiversAccountThrowableError = param + } + var listTransparentReceiversAccountCallsCount = 0 + var listTransparentReceiversAccountCalled: Bool { + return listTransparentReceiversAccountCallsCount > 0 + } + var listTransparentReceiversAccountReceivedAccount: Int32? + var listTransparentReceiversAccountReturnValue: [TransparentAddress]! + func setListTransparentReceiversAccountReturnValue(_ param: [TransparentAddress]) async { + listTransparentReceiversAccountReturnValue = param + } + var listTransparentReceiversAccountClosure: ((Int32) async throws -> [TransparentAddress])? + func setListTransparentReceiversAccountClosure(_ param: ((Int32) async throws -> [TransparentAddress])?) async { + listTransparentReceiversAccountClosure = param + } + + func listTransparentReceivers(account: Int32) async throws -> [TransparentAddress] { + if let error = listTransparentReceiversAccountThrowableError { + throw error + } + listTransparentReceiversAccountCallsCount += 1 + listTransparentReceiversAccountReceivedAccount = account + if let closure = listTransparentReceiversAccountClosure { + return try await closure(account) + } else { + return listTransparentReceiversAccountReturnValue + } + } + + // MARK: - getVerifiedBalance + + var getVerifiedBalanceAccountThrowableError: Error? + func setGetVerifiedBalanceAccountThrowableError(_ param: Error?) async { + getVerifiedBalanceAccountThrowableError = param + } + var getVerifiedBalanceAccountCallsCount = 0 + var getVerifiedBalanceAccountCalled: Bool { + return getVerifiedBalanceAccountCallsCount > 0 + } + var getVerifiedBalanceAccountReceivedAccount: Int32? + var getVerifiedBalanceAccountReturnValue: Int64! + func setGetVerifiedBalanceAccountReturnValue(_ param: Int64) async { + getVerifiedBalanceAccountReturnValue = param + } + var getVerifiedBalanceAccountClosure: ((Int32) async throws -> Int64)? + func setGetVerifiedBalanceAccountClosure(_ param: ((Int32) async throws -> Int64)?) async { + getVerifiedBalanceAccountClosure = param + } + + func getVerifiedBalance(account: Int32) async throws -> Int64 { + if let error = getVerifiedBalanceAccountThrowableError { + throw error + } + getVerifiedBalanceAccountCallsCount += 1 + getVerifiedBalanceAccountReceivedAccount = account + if let closure = getVerifiedBalanceAccountClosure { + return try await closure(account) + } else { + return getVerifiedBalanceAccountReturnValue + } + } + + // MARK: - getVerifiedTransparentBalance + + var getVerifiedTransparentBalanceAccountThrowableError: Error? + func setGetVerifiedTransparentBalanceAccountThrowableError(_ param: Error?) async { + getVerifiedTransparentBalanceAccountThrowableError = param + } + var getVerifiedTransparentBalanceAccountCallsCount = 0 + var getVerifiedTransparentBalanceAccountCalled: Bool { + return getVerifiedTransparentBalanceAccountCallsCount > 0 + } + var getVerifiedTransparentBalanceAccountReceivedAccount: Int32? + var getVerifiedTransparentBalanceAccountReturnValue: Int64! + func setGetVerifiedTransparentBalanceAccountReturnValue(_ param: Int64) async { + getVerifiedTransparentBalanceAccountReturnValue = param + } + var getVerifiedTransparentBalanceAccountClosure: ((Int32) async throws -> Int64)? + func setGetVerifiedTransparentBalanceAccountClosure(_ param: ((Int32) async throws -> Int64)?) async { + getVerifiedTransparentBalanceAccountClosure = param + } + + func getVerifiedTransparentBalance(account: Int32) async throws -> Int64 { + if let error = getVerifiedTransparentBalanceAccountThrowableError { + throw error + } + getVerifiedTransparentBalanceAccountCallsCount += 1 + getVerifiedTransparentBalanceAccountReceivedAccount = account + if let closure = getVerifiedTransparentBalanceAccountClosure { + return try await closure(account) + } else { + return getVerifiedTransparentBalanceAccountReturnValue + } + } + + // MARK: - validateCombinedChain + + var validateCombinedChainLimitThrowableError: Error? + func setValidateCombinedChainLimitThrowableError(_ param: Error?) async { + validateCombinedChainLimitThrowableError = param + } + var validateCombinedChainLimitCallsCount = 0 + var validateCombinedChainLimitCalled: Bool { + return validateCombinedChainLimitCallsCount > 0 + } + var validateCombinedChainLimitReceivedLimit: UInt32? + var validateCombinedChainLimitClosure: ((UInt32) async throws -> Void)? + func setValidateCombinedChainLimitClosure(_ param: ((UInt32) async throws -> Void)?) async { + validateCombinedChainLimitClosure = param + } + + func validateCombinedChain(limit: UInt32) async throws { + if let error = validateCombinedChainLimitThrowableError { + throw error + } + validateCombinedChainLimitCallsCount += 1 + validateCombinedChainLimitReceivedLimit = limit + try await validateCombinedChainLimitClosure?(limit) + } + + // MARK: - rewindToHeight + + var rewindToHeightHeightThrowableError: Error? + func setRewindToHeightHeightThrowableError(_ param: Error?) async { + rewindToHeightHeightThrowableError = param + } + var rewindToHeightHeightCallsCount = 0 + var rewindToHeightHeightCalled: Bool { + return rewindToHeightHeightCallsCount > 0 + } + var rewindToHeightHeightReceivedHeight: Int32? + var rewindToHeightHeightClosure: ((Int32) async throws -> Void)? + func setRewindToHeightHeightClosure(_ param: ((Int32) async throws -> Void)?) async { + rewindToHeightHeightClosure = param + } + + func rewindToHeight(height: Int32) async throws { + if let error = rewindToHeightHeightThrowableError { + throw error + } + rewindToHeightHeightCallsCount += 1 + rewindToHeightHeightReceivedHeight = height + try await rewindToHeightHeightClosure?(height) + } + + // MARK: - rewindCacheToHeight + + var rewindCacheToHeightHeightThrowableError: Error? + func setRewindCacheToHeightHeightThrowableError(_ param: Error?) async { + rewindCacheToHeightHeightThrowableError = param + } + var rewindCacheToHeightHeightCallsCount = 0 + var rewindCacheToHeightHeightCalled: Bool { + return rewindCacheToHeightHeightCallsCount > 0 + } + var rewindCacheToHeightHeightReceivedHeight: Int32? + var rewindCacheToHeightHeightClosure: ((Int32) async throws -> Void)? + func setRewindCacheToHeightHeightClosure(_ param: ((Int32) async throws -> Void)?) async { + rewindCacheToHeightHeightClosure = param + } + + func rewindCacheToHeight(height: Int32) async throws { + if let error = rewindCacheToHeightHeightThrowableError { + throw error + } + rewindCacheToHeightHeightCallsCount += 1 + rewindCacheToHeightHeightReceivedHeight = height + try await rewindCacheToHeightHeightClosure?(height) + } + + // MARK: - scanBlocks + + var scanBlocksLimitThrowableError: Error? + func setScanBlocksLimitThrowableError(_ param: Error?) async { + scanBlocksLimitThrowableError = param + } + var scanBlocksLimitCallsCount = 0 + var scanBlocksLimitCalled: Bool { + return scanBlocksLimitCallsCount > 0 + } + var scanBlocksLimitReceivedLimit: UInt32? + var scanBlocksLimitClosure: ((UInt32) async throws -> Void)? + func setScanBlocksLimitClosure(_ param: ((UInt32) async throws -> Void)?) async { + scanBlocksLimitClosure = param + } + + func scanBlocks(limit: UInt32) async throws { + if let error = scanBlocksLimitThrowableError { + throw error + } + scanBlocksLimitCallsCount += 1 + scanBlocksLimitReceivedLimit = limit + try await scanBlocksLimitClosure?(limit) + } + + // MARK: - putUnspentTransparentOutput + + var putUnspentTransparentOutputTxidIndexScriptValueHeightThrowableError: Error? + func setPutUnspentTransparentOutputTxidIndexScriptValueHeightThrowableError(_ param: Error?) async { + putUnspentTransparentOutputTxidIndexScriptValueHeightThrowableError = param + } + var putUnspentTransparentOutputTxidIndexScriptValueHeightCallsCount = 0 + var putUnspentTransparentOutputTxidIndexScriptValueHeightCalled: Bool { + return putUnspentTransparentOutputTxidIndexScriptValueHeightCallsCount > 0 + } + var putUnspentTransparentOutputTxidIndexScriptValueHeightReceivedArguments: (txid: [UInt8], index: Int, script: [UInt8], value: Int64, height: BlockHeight)? + var putUnspentTransparentOutputTxidIndexScriptValueHeightClosure: (([UInt8], Int, [UInt8], Int64, BlockHeight) async throws -> Void)? + func setPutUnspentTransparentOutputTxidIndexScriptValueHeightClosure(_ param: (([UInt8], Int, [UInt8], Int64, BlockHeight) async throws -> Void)?) async { + putUnspentTransparentOutputTxidIndexScriptValueHeightClosure = param + } + + func putUnspentTransparentOutput(txid: [UInt8], index: Int, script: [UInt8], value: Int64, height: BlockHeight) async throws { + if let error = putUnspentTransparentOutputTxidIndexScriptValueHeightThrowableError { + throw error + } + putUnspentTransparentOutputTxidIndexScriptValueHeightCallsCount += 1 + putUnspentTransparentOutputTxidIndexScriptValueHeightReceivedArguments = (txid: txid, index: index, script: script, value: value, height: height) + try await putUnspentTransparentOutputTxidIndexScriptValueHeightClosure?(txid, index, script, value, height) + } + + // MARK: - shieldFunds + + var shieldFundsUskMemoShieldingThresholdThrowableError: Error? + func setShieldFundsUskMemoShieldingThresholdThrowableError(_ param: Error?) async { + shieldFundsUskMemoShieldingThresholdThrowableError = param + } + var shieldFundsUskMemoShieldingThresholdCallsCount = 0 + var shieldFundsUskMemoShieldingThresholdCalled: Bool { + return shieldFundsUskMemoShieldingThresholdCallsCount > 0 + } + var shieldFundsUskMemoShieldingThresholdReceivedArguments: (usk: UnifiedSpendingKey, memo: MemoBytes?, shieldingThreshold: Zatoshi)? + var shieldFundsUskMemoShieldingThresholdReturnValue: Int64! + func setShieldFundsUskMemoShieldingThresholdReturnValue(_ param: Int64) async { + shieldFundsUskMemoShieldingThresholdReturnValue = param + } + var shieldFundsUskMemoShieldingThresholdClosure: ((UnifiedSpendingKey, MemoBytes?, Zatoshi) async throws -> Int64)? + func setShieldFundsUskMemoShieldingThresholdClosure(_ param: ((UnifiedSpendingKey, MemoBytes?, Zatoshi) async throws -> Int64)?) async { + shieldFundsUskMemoShieldingThresholdClosure = param + } + + func shieldFunds(usk: UnifiedSpendingKey, memo: MemoBytes?, shieldingThreshold: Zatoshi) async throws -> Int64 { + if let error = shieldFundsUskMemoShieldingThresholdThrowableError { + throw error + } + shieldFundsUskMemoShieldingThresholdCallsCount += 1 + shieldFundsUskMemoShieldingThresholdReceivedArguments = (usk: usk, memo: memo, shieldingThreshold: shieldingThreshold) + if let closure = shieldFundsUskMemoShieldingThresholdClosure { + return try await closure(usk, memo, shieldingThreshold) + } else { + return shieldFundsUskMemoShieldingThresholdReturnValue + } + } + + // MARK: - consensusBranchIdFor + + + nonisolated func consensusBranchIdFor(height: Int32) throws -> Int32 { + try consensusBranchIdForHeightClosure!(height) + } + + // MARK: - initBlockMetadataDb + + var initBlockMetadataDbThrowableError: Error? + func setInitBlockMetadataDbThrowableError(_ param: Error?) async { + initBlockMetadataDbThrowableError = param + } + var initBlockMetadataDbCallsCount = 0 + var initBlockMetadataDbCalled: Bool { + return initBlockMetadataDbCallsCount > 0 + } + var initBlockMetadataDbClosure: (() async throws -> Void)? + func setInitBlockMetadataDbClosure(_ param: (() async throws -> Void)?) async { + initBlockMetadataDbClosure = param + } + + func initBlockMetadataDb() async throws { + if let error = initBlockMetadataDbThrowableError { + throw error + } + initBlockMetadataDbCallsCount += 1 + try await initBlockMetadataDbClosure?() + } + + // MARK: - writeBlocksMetadata + + var writeBlocksMetadataBlocksThrowableError: Error? + func setWriteBlocksMetadataBlocksThrowableError(_ param: Error?) async { + writeBlocksMetadataBlocksThrowableError = param + } + var writeBlocksMetadataBlocksCallsCount = 0 + var writeBlocksMetadataBlocksCalled: Bool { + return writeBlocksMetadataBlocksCallsCount > 0 + } + var writeBlocksMetadataBlocksReceivedBlocks: [ZcashCompactBlock]? + var writeBlocksMetadataBlocksClosure: (([ZcashCompactBlock]) async throws -> Void)? + func setWriteBlocksMetadataBlocksClosure(_ param: (([ZcashCompactBlock]) async throws -> Void)?) async { + writeBlocksMetadataBlocksClosure = param + } + + func writeBlocksMetadata(blocks: [ZcashCompactBlock]) async throws { + if let error = writeBlocksMetadataBlocksThrowableError { + throw error + } + writeBlocksMetadataBlocksCallsCount += 1 + writeBlocksMetadataBlocksReceivedBlocks = blocks + try await writeBlocksMetadataBlocksClosure?(blocks) + } + + // MARK: - latestCachedBlockHeight + + var latestCachedBlockHeightCallsCount = 0 + var latestCachedBlockHeightCalled: Bool { + return latestCachedBlockHeightCallsCount > 0 + } + var latestCachedBlockHeightReturnValue: BlockHeight! + func setLatestCachedBlockHeightReturnValue(_ param: BlockHeight) async { + latestCachedBlockHeightReturnValue = param + } + var latestCachedBlockHeightClosure: (() async -> BlockHeight)? + func setLatestCachedBlockHeightClosure(_ param: (() async -> BlockHeight)?) async { + latestCachedBlockHeightClosure = param + } + + func latestCachedBlockHeight() async -> BlockHeight { + latestCachedBlockHeightCallsCount += 1 + if let closure = latestCachedBlockHeightClosure { + return await closure() + } else { + return latestCachedBlockHeightReturnValue + } + } + +} diff --git a/Tests/TestUtils/Sourcery/generateMocks.sh b/Tests/TestUtils/Sourcery/generateMocks.sh new file mode 100755 index 00000000..16bde030 --- /dev/null +++ b/Tests/TestUtils/Sourcery/generateMocks.sh @@ -0,0 +1,22 @@ +#!/bin/zsh + +sourcery_version=2.0.2 + +if which sourcery >/dev/null; then + if [[ $(sourcery --version) != $sourcery_version ]]; then + echo "warning: Compatible sourcery version not installed. Install sourcer $sourcery_version. Currently installed version is $(sourcery --version)" + exit 1 + fi + + sourcery \ + --sources ./ \ + --sources ../../../Sources/ \ + --templates AutoMockable.stencil \ + --output GeneratedMocks/ + +else + echo "warning: sourcery not installed" +fi + + + diff --git a/Tests/TestUtils/Stubs.swift b/Tests/TestUtils/Stubs.swift index 14be95d5..3b32e519 100644 --- a/Tests/TestUtils/Stubs.swift +++ b/Tests/TestUtils/Stubs.swift @@ -49,342 +49,128 @@ extension LightWalletServiceMockResponse { } } -class MockRustBackend: ZcashRustBackendWelding { - static var networkType = NetworkType.testnet - static var mockDataDb = false - static var mockAcounts = false - static var mockError: RustWeldingError? - static var mockLastError: String? - static var mockAccounts: [SaplingExtendedSpendingKey]? - static var mockAddresses: [String]? - static var mockBalance: Int64? - static var mockVerifiedBalance: Int64? - static var mockMemo: String? - static var mockSentMemo: String? - static var mockValidateCombinedChainSuccessRate: Float? - static var mockValidateCombinedChainFailAfterAttempts: Int? - static var mockValidateCombinedChainKeepFailing = false - static var mockValidateCombinedChainFailureHeight: BlockHeight = 0 - static var mockScanblocksSuccessRate: Float? - static var mockCreateToAddress: Int64? - static var rustBackend = ZcashRustBackend.self - static var consensusBranchID: Int32? - static var writeBlocksMetadataResult: () throws -> Bool = { true } - static var rewindCacheToHeightResult: () -> Bool = { true } - static func latestCachedBlockHeight(fsBlockDbRoot: URL) async -> ZcashLightClientKit.BlockHeight { - .empty() +class RustBackendMockHelper { + let rustBackendMock: ZcashRustBackendWeldingMock + var mockValidateCombinedChainFailAfterAttempts: Int? + + init( + rustBackend: ZcashRustBackendWelding, + consensusBranchID: Int32? = nil, + mockValidateCombinedChainSuccessRate: Float? = nil, + mockValidateCombinedChainFailAfterAttempts: Int? = nil, + mockValidateCombinedChainKeepFailing: Bool = false, + mockValidateCombinedChainFailureError: RustWeldingError = .chainValidationFailed(message: nil) + ) async { + self.mockValidateCombinedChainFailAfterAttempts = mockValidateCombinedChainFailAfterAttempts + self.rustBackendMock = ZcashRustBackendWeldingMock( + consensusBranchIdForHeightClosure: { height in + if let consensusBranchID { + return consensusBranchID + } else { + return try rustBackend.consensusBranchIdFor(height: height) + } + } + ) + await setupDefaultMock( + rustBackend: rustBackend, + mockValidateCombinedChainSuccessRate: mockValidateCombinedChainSuccessRate, + mockValidateCombinedChainKeepFailing: mockValidateCombinedChainKeepFailing, + mockValidateCombinedChainFailureError: mockValidateCombinedChainFailureError + ) } - static func rewindCacheToHeight(fsBlockDbRoot: URL, height: Int32) async -> Bool { - rewindCacheToHeightResult() - } + private func setupDefaultMock( + rustBackend: ZcashRustBackendWelding, + mockValidateCombinedChainSuccessRate: Float? = nil, + mockValidateCombinedChainKeepFailing: Bool = false, + mockValidateCombinedChainFailureError: RustWeldingError = .chainValidationFailed(message: nil) + ) async { + await rustBackendMock.setLatestCachedBlockHeightReturnValue(.empty()) + await rustBackendMock.setInitBlockMetadataDbClosure() { } + await rustBackendMock.setWriteBlocksMetadataBlocksClosure() { _ in } + await rustBackendMock.setInitAccountsTableUfvksClosure() { _ in } + await rustBackendMock.setCreateToAddressUskToValueMemoReturnValue(-1) + await rustBackendMock.setShieldFundsUskMemoShieldingThresholdReturnValue(-1) + await rustBackendMock.setGetTransparentBalanceAccountReturnValue(0) + await rustBackendMock.setGetVerifiedBalanceAccountReturnValue(0) + await rustBackendMock.setListTransparentReceiversAccountReturnValue([]) + await rustBackendMock.setGetCurrentAddressAccountThrowableError(KeyDerivationErrors.unableToDerive) + await rustBackendMock.setGetNextAvailableAddressAccountThrowableError(KeyDerivationErrors.unableToDerive) + await rustBackendMock.setShieldFundsUskMemoShieldingThresholdReturnValue(-1) + await rustBackendMock.setCreateAccountSeedThrowableError(KeyDerivationErrors.unableToDerive) + await rustBackendMock.setGetReceivedMemoIdNoteReturnValue(nil) + await rustBackendMock.setGetSentMemoIdNoteReturnValue(nil) + await rustBackendMock.setCreateToAddressUskToValueMemoReturnValue(-1) + await rustBackendMock.setInitDataDbSeedReturnValue(.seedRequired) + await rustBackendMock.setGetNearestRewindHeightHeightReturnValue(-1) + await rustBackendMock.setInitBlocksTableHeightHashTimeSaplingTreeClosure() { _, _, _, _ in } + await rustBackendMock.setPutUnspentTransparentOutputTxidIndexScriptValueHeightClosure() { _, _, _, _, _ in } + await rustBackendMock.setCreateToAddressUskToValueMemoReturnValue(-1) + await rustBackendMock.setCreateToAddressUskToValueMemoReturnValue(-1) + await rustBackendMock.setDecryptAndStoreTransactionTxBytesMinedHeightThrowableError(RustWeldingError.genericError(message: "mock fail")) - static func initBlockMetadataDb(fsBlockDbRoot: URL) async throws -> Bool { - true - } - - static func writeBlocksMetadata(fsBlockDbRoot: URL, blocks: [ZcashLightClientKit.ZcashCompactBlock]) async throws -> Bool { - try writeBlocksMetadataResult() - } - - static func initAccountsTable(dbData: URL, ufvks: [ZcashLightClientKit.UnifiedFullViewingKey], networkType: ZcashLightClientKit.NetworkType) async throws { } - - static func createToAddress(dbData: URL, usk: ZcashLightClientKit.UnifiedSpendingKey, to address: String, value: Int64, memo: ZcashLightClientKit.MemoBytes?, spendParamsPath: String, outputParamsPath: String, networkType: ZcashLightClientKit.NetworkType) async -> Int64 { - -1 - } - - static func shieldFunds( - dbData: URL, - usk: ZcashLightClientKit.UnifiedSpendingKey, - memo: ZcashLightClientKit.MemoBytes?, - shieldingThreshold: Zatoshi, - spendParamsPath: String, - outputParamsPath: String, - networkType: ZcashLightClientKit.NetworkType - ) async -> Int64 { - -1 - } - - static func getAddressMetadata(_ address: String) -> ZcashLightClientKit.AddressMetadata? { - nil - } - - static func clearUtxos(dbData: URL, address: ZcashLightClientKit.TransparentAddress, sinceHeight: ZcashLightClientKit.BlockHeight, networkType: ZcashLightClientKit.NetworkType) async throws -> Int32 { - 0 - } - - static func getTransparentBalance(dbData: URL, account: Int32, networkType: ZcashLightClientKit.NetworkType) async throws -> Int64 { 0 } - - static func getVerifiedTransparentBalance(dbData: URL, account: Int32, networkType: ZcashLightClientKit.NetworkType) async throws -> Int64 { 0 } - - static func listTransparentReceivers(dbData: URL, account: Int32, networkType: ZcashLightClientKit.NetworkType) async throws -> [ZcashLightClientKit.TransparentAddress] { - [] - } - - static func deriveUnifiedFullViewingKey(from spendingKey: ZcashLightClientKit.UnifiedSpendingKey, networkType: ZcashLightClientKit.NetworkType) throws -> ZcashLightClientKit.UnifiedFullViewingKey { - throw KeyDerivationErrors.unableToDerive - } - - static func deriveUnifiedSpendingKey(from seed: [UInt8], accountIndex: Int32, networkType: ZcashLightClientKit.NetworkType) throws -> ZcashLightClientKit.UnifiedSpendingKey { - throw KeyDerivationErrors.unableToDerive - } - - static func getCurrentAddress(dbData: URL, account: Int32, networkType: ZcashLightClientKit.NetworkType) async throws -> ZcashLightClientKit.UnifiedAddress { - throw KeyDerivationErrors.unableToDerive - } - - static func getNextAvailableAddress(dbData: URL, account: Int32, networkType: ZcashLightClientKit.NetworkType) async throws -> ZcashLightClientKit.UnifiedAddress { - throw KeyDerivationErrors.unableToDerive - } - - static func getSaplingReceiver(for uAddr: ZcashLightClientKit.UnifiedAddress) throws -> ZcashLightClientKit.SaplingAddress { - throw KeyDerivationErrors.unableToDerive - } - - static func getTransparentReceiver(for uAddr: ZcashLightClientKit.UnifiedAddress) throws -> ZcashLightClientKit.TransparentAddress { - throw KeyDerivationErrors.unableToDerive - } - - static func shieldFunds(dbData: URL, usk: ZcashLightClientKit.UnifiedSpendingKey, memo: ZcashLightClientKit.MemoBytes, spendParamsPath: String, outputParamsPath: String, networkType: ZcashLightClientKit.NetworkType) async -> Int64 { - -1 - } - - static func receiverTypecodesOnUnifiedAddress(_ address: String) throws -> [UInt32] { - throw KeyDerivationErrors.receiverNotFound - } - - static func createAccount(dbData: URL, seed: [UInt8], networkType: ZcashLightClientKit.NetworkType) async throws -> ZcashLightClientKit.UnifiedSpendingKey { - throw KeyDerivationErrors.unableToDerive - } - - static func getReceivedMemo(dbData: URL, idNote: Int64, networkType: ZcashLightClientKit.NetworkType) async -> ZcashLightClientKit.Memo? { nil } - - static func getSentMemo(dbData: URL, idNote: Int64, networkType: ZcashLightClientKit.NetworkType) async -> ZcashLightClientKit.Memo? { nil } - - static func createToAddress(dbData: URL, usk: ZcashLightClientKit.UnifiedSpendingKey, to address: String, value: Int64, memo: ZcashLightClientKit.MemoBytes, spendParamsPath: String, outputParamsPath: String, networkType: ZcashLightClientKit.NetworkType) async -> Int64 { - -1 - } - - static func initDataDb(dbData: URL, seed: [UInt8]?, networkType: ZcashLightClientKit.NetworkType) async throws -> ZcashLightClientKit.DbInitResult { - .seedRequired - } - - static func deriveSaplingAddressFromViewingKey(_ extfvk: ZcashLightClientKit.SaplingExtendedFullViewingKey, networkType: ZcashLightClientKit.NetworkType) throws -> ZcashLightClientKit.SaplingAddress { - throw RustWeldingError.unableToDeriveKeys - } - - static func isValidSaplingExtendedSpendingKey(_ key: String, networkType: ZcashLightClientKit.NetworkType) -> Bool { false } - - static func deriveSaplingExtendedFullViewingKeys(seed: [UInt8], accounts: Int32, networkType: ZcashLightClientKit.NetworkType) throws -> [ZcashLightClientKit.SaplingExtendedFullViewingKey]? { - nil - } - - static func isValidUnifiedAddress(_ address: String, networkType: ZcashLightClientKit.NetworkType) -> Bool { - false - } - - static func deriveSaplingExtendedFullViewingKey(_ spendingKey: SaplingExtendedSpendingKey, networkType: ZcashLightClientKit.NetworkType) throws -> ZcashLightClientKit.SaplingExtendedFullViewingKey? { - nil - } - - public func deriveViewingKeys(seed: [UInt8], numberOfAccounts: Int) throws -> [UnifiedFullViewingKey] { [] } - - static func getNearestRewindHeight(dbData: URL, height: Int32, networkType: NetworkType) async -> Int32 { -1 } - - static func network(dbData: URL, address: String, sinceHeight: BlockHeight, networkType: NetworkType) async throws -> Int32 { -1 } - - static func initAccountsTable(dbData: URL, ufvks: [UnifiedFullViewingKey], networkType: NetworkType) async throws -> Bool { false } - - static func putUnspentTransparentOutput( - dbData: URL, - txid: [UInt8], - index: Int, - script: [UInt8], - value: Int64, - height: BlockHeight, - networkType: NetworkType - ) async throws -> Bool { - false - } - - static func downloadedUtxoBalance(dbData: URL, address: String, networkType: NetworkType) async throws -> WalletBalance { - throw RustWeldingError.genericError(message: "unimplemented") - } - - static func createToAddress( - dbData: URL, - account: Int32, - extsk: String, - to address: String, - value: Int64, - memo: String?, - spendParamsPath: String, - outputParamsPath: String, - networkType: NetworkType - ) async -> Int64 { - -1 - } - - static func deriveTransparentAddressFromSeed(seed: [UInt8], account: Int, index: Int, networkType: NetworkType) throws -> TransparentAddress { - throw KeyDerivationErrors.unableToDerive - } - - static func deriveUnifiedFullViewingKeyFromSeed(_ seed: [UInt8], numberOfAccounts: Int32, networkType: NetworkType) throws -> [UnifiedFullViewingKey] { - throw KeyDerivationErrors.unableToDerive - } - - static func isValidSaplingExtendedFullViewingKey(_ key: String, networkType: NetworkType) -> Bool { false } - - static func isValidUnifiedFullViewingKey(_ ufvk: String, networkType: NetworkType) -> Bool { false } - - static func deriveSaplingExtendedSpendingKeys(seed: [UInt8], accounts: Int32, networkType: NetworkType) throws -> [SaplingExtendedSpendingKey]? { nil } - - static func consensusBranchIdFor(height: Int32, networkType: NetworkType) throws -> Int32 { - guard let consensus = consensusBranchID else { - return try rustBackend.consensusBranchIdFor(height: height, networkType: networkType) + await rustBackendMock.setInitDataDbSeedClosure() { seed in + return try await rustBackend.initDataDb(seed: seed) } - return consensus - } - static func lastError() -> RustWeldingError? { - mockError ?? rustBackend.lastError() - } - - static func getLastError() -> String? { - mockLastError ?? rustBackend.getLastError() - } - - static func isValidSaplingAddress(_ address: String, networkType: NetworkType) -> Bool { - true - } - - static func isValidTransparentAddress(_ address: String, networkType: NetworkType) -> Bool { - true - } - - static func initDataDb(dbData: URL, networkType: NetworkType) async throws { - if !mockDataDb { - _ = try await rustBackend.initDataDb(dbData: dbData, seed: nil, networkType: networkType) - } - } - - static func initBlocksTable( - dbData: URL, - height: Int32, - hash: String, - time: UInt32, - saplingTree: String, - networkType: NetworkType - ) async throws { - if !mockDataDb { + await rustBackendMock.setInitBlocksTableHeightHashTimeSaplingTreeClosure() { height, hash, time, saplingTree in try await rustBackend.initBlocksTable( - dbData: dbData, height: height, hash: hash, time: time, - saplingTree: saplingTree, - networkType: networkType + saplingTree: saplingTree ) } - } - - static func getBalance(dbData: URL, account: Int32, networkType: NetworkType) async throws -> Int64 { - if let balance = mockBalance { - return balance - } - return try await rustBackend.getBalance(dbData: dbData, account: account, networkType: networkType) - } - - static func getVerifiedBalance(dbData: URL, account: Int32, networkType: NetworkType) async throws -> Int64 { - if let balance = mockVerifiedBalance { - return balance + + await rustBackendMock.setGetBalanceAccountClosure() { account in + return try await rustBackend.getBalance(account: account) } - return try await rustBackend.getVerifiedBalance(dbData: dbData, account: account, networkType: networkType) - } - - static func validateCombinedChain(fsBlockDbRoot: URL, dbData: URL, networkType: NetworkType, limit: UInt32 = 0) async -> Int32 { - if let rate = self.mockValidateCombinedChainSuccessRate { - if shouldSucceed(successRate: rate) { - return await validationResult(fsBlockDbRoot: fsBlockDbRoot, dbData: dbData, networkType: networkType) - } else { - return Int32(mockValidateCombinedChainFailureHeight) - } - } else if let attempts = self.mockValidateCombinedChainFailAfterAttempts { - self.mockValidateCombinedChainFailAfterAttempts = attempts - 1 - if attempts > 0 { - return await validationResult(fsBlockDbRoot: fsBlockDbRoot, dbData: dbData, networkType: networkType) - } else { - if attempts == 0 { - return Int32(mockValidateCombinedChainFailureHeight) - } else if attempts < 0 && mockValidateCombinedChainKeepFailing { - return Int32(mockValidateCombinedChainFailureHeight) + await rustBackendMock.setGetVerifiedBalanceAccountClosure() { account in + return try await rustBackend.getVerifiedBalance(account: account) + } + + await rustBackendMock.setValidateCombinedChainLimitClosure() { [weak self] limit in + guard let self else { throw RustWeldingError.genericError(message: "Self is nil") } + if let rate = mockValidateCombinedChainSuccessRate { + if Self.shouldSucceed(successRate: rate) { + return try await rustBackend.validateCombinedChain(limit: limit) } else { - return await validationResult(fsBlockDbRoot: fsBlockDbRoot, dbData: dbData, networkType: networkType) + throw mockValidateCombinedChainFailureError + } + } else if let attempts = self.mockValidateCombinedChainFailAfterAttempts { + self.mockValidateCombinedChainFailAfterAttempts = attempts - 1 + if attempts > 0 { + return try await rustBackend.validateCombinedChain(limit: limit) + } else { + if attempts == 0 { + throw mockValidateCombinedChainFailureError + } else if attempts < 0 && mockValidateCombinedChainKeepFailing { + throw mockValidateCombinedChainFailureError + } else { + return try await rustBackend.validateCombinedChain(limit: limit) + } } - } - } - return await rustBackend.validateCombinedChain(fsBlockDbRoot: fsBlockDbRoot, dbData: dbData, networkType: networkType) - } - - private static func validationResult(fsBlockDbRoot: URL, dbData: URL, networkType: NetworkType) async -> Int32 { - if mockDataDb { - return -1 - } else { - return await rustBackend.validateCombinedChain(fsBlockDbRoot: fsBlockDbRoot, dbData: dbData, networkType: networkType) - } - } - - static func rewindToHeight(dbData: URL, height: Int32, networkType: NetworkType) async -> Bool { - mockDataDb ? true : rustBackend.rewindToHeight(dbData: dbData, height: height, networkType: networkType) - } - - static func scanBlocks(fsBlockDbRoot: URL, dbData: URL, limit: UInt32, networkType: NetworkType) async -> Bool { - if let rate = mockScanblocksSuccessRate { - if shouldSucceed(successRate: rate) { - return mockDataDb ? true : await rustBackend.scanBlocks(fsBlockDbRoot: fsBlockDbRoot, dbData: dbData, networkType: networkType) } else { - return false + return try await rustBackend.validateCombinedChain(limit: limit) } } - return await rustBackend.scanBlocks(fsBlockDbRoot: fsBlockDbRoot, dbData: dbData, networkType: Self.networkType) + + await rustBackendMock.setRewindToHeightHeightClosure() { height in + try await rustBackend.rewindToHeight(height: height) + } + + await rustBackendMock.setRewindCacheToHeightHeightClosure() { _ in } + + await rustBackendMock.setScanBlocksLimitClosure() { limit in + try await rustBackend.scanBlocks(limit: limit) + } } - - static func createToAddress( - dbData: URL, - account: Int32, - extsk: String, - consensusBranchId: Int32, - to address: String, - value: Int64, - memo: String?, - spendParamsPath: String, - outputParamsPath: String, - networkType: NetworkType - ) async -> Int64 { - -1 - } - - static func shouldSucceed(successRate: Float) -> Bool { + + private static func shouldSucceed(successRate: Float) -> Bool { let random = Float.random(in: 0.0...1.0) return random <= successRate } - - static func deriveExtendedFullViewingKey(_ spendingKey: String, networkType: NetworkType) throws -> String? { - nil - } - - static func deriveExtendedFullViewingKeys(seed: String, accounts: Int32, networkType: NetworkType) throws -> [String]? { - nil - } - - static func deriveExtendedSpendingKeys(seed: String, accounts: Int32, networkType: NetworkType) throws -> [String]? { - nil - } - - static func decryptAndStoreTransaction(dbData: URL, txBytes: [UInt8], minedHeight: Int32, networkType: NetworkType) async -> Bool { - false - } } extension SaplingParamsSourceURL { diff --git a/Tests/TestUtils/SynchronizerMock.swift b/Tests/TestUtils/SynchronizerMock.swift deleted file mode 100644 index 2f168535..00000000 --- a/Tests/TestUtils/SynchronizerMock.swift +++ /dev/null @@ -1,174 +0,0 @@ -// -// SynchronizerMock.swift -// -// -// Created by Michal Fousek on 20.03.2023. -// - -import Combine -import Foundation -@testable import ZcashLightClientKit - -class SynchronizerMock: Synchronizer { - init() { } - - var underlyingAlias: ZcashSynchronizerAlias! = nil - var alias: ZcashLightClientKit.ZcashSynchronizerAlias { underlyingAlias } - - var underlyingStateStream: AnyPublisher! = nil - var stateStream: AnyPublisher { underlyingStateStream } - - var underlyingLatestState: SynchronizerState! = nil - var latestState: SynchronizerState { underlyingLatestState } - - var underlyingEventStream: AnyPublisher! = nil - var eventStream: AnyPublisher { underlyingEventStream } - - var underlyingConnectionState: ConnectionState! = nil - var connectionState: ConnectionState { underlyingConnectionState } - - let metrics = SDKMetrics() - - var prepareWithSeedViewingKeysWalletBirthdayClosure: ( - ([UInt8]?, [UnifiedFullViewingKey], BlockHeight) async throws -> Initializer.InitializationResult - )! = nil - func prepare( - with seed: [UInt8]?, - viewingKeys: [UnifiedFullViewingKey], - walletBirthday: BlockHeight - ) async throws -> Initializer.InitializationResult { - return try await prepareWithSeedViewingKeysWalletBirthdayClosure(seed, viewingKeys, walletBirthday) - } - - var startRetryClosure: ((Bool) async throws -> Void)! = nil - func start(retry: Bool) async throws { - try await startRetryClosure(retry) - } - - var stopClosure: (() async -> Void)! = nil - func stop() async { - await stopClosure() - } - - var getSaplingAddressAccountIndexClosure: ((Int) async throws -> SaplingAddress)! = nil - func getSaplingAddress(accountIndex: Int) async throws -> SaplingAddress { - return try await getSaplingAddressAccountIndexClosure(accountIndex) - } - - var getUnifiedAddressAccountIndexClosure: ((Int) async throws -> UnifiedAddress)! = nil - func getUnifiedAddress(accountIndex: Int) async throws -> UnifiedAddress { - return try await getUnifiedAddressAccountIndexClosure(accountIndex) - } - - var getTransparentAddressAccountIndexClosure: ((Int) async throws -> TransparentAddress)! = nil - func getTransparentAddress(accountIndex: Int) async throws -> TransparentAddress { - return try await getTransparentAddressAccountIndexClosure(accountIndex) - } - - var sendToAddressSpendingKeyZatoshiToAddressMemoClosure: ( - (UnifiedSpendingKey, Zatoshi, Recipient, Memo?) async throws -> PendingTransactionEntity - )! = nil - func sendToAddress(spendingKey: UnifiedSpendingKey, zatoshi: Zatoshi, toAddress: Recipient, memo: Memo?) async throws -> PendingTransactionEntity { - return try await sendToAddressSpendingKeyZatoshiToAddressMemoClosure(spendingKey, zatoshi, toAddress, memo) - } - - var shieldFundsSpendingKeyMemoShieldingThresholdClosure: ((UnifiedSpendingKey, Memo, Zatoshi) async throws -> PendingTransactionEntity)! = nil - func shieldFunds(spendingKey: UnifiedSpendingKey, memo: Memo, shieldingThreshold: Zatoshi) async throws -> PendingTransactionEntity { - return try await shieldFundsSpendingKeyMemoShieldingThresholdClosure(spendingKey, memo, shieldingThreshold) - } - - var cancelSpendTransactionClosure: ((PendingTransactionEntity) async -> Bool)! = nil - func cancelSpend(transaction: PendingTransactionEntity) async -> Bool { - return await cancelSpendTransactionClosure(transaction) - } - - var underlyingPendingTransactions: [PendingTransactionEntity]! = nil - var pendingTransactions: [PendingTransactionEntity] { - get async { underlyingPendingTransactions } - } - - var underlyingClearedTransactions: [ZcashTransaction.Overview]! = nil - var clearedTransactions: [ZcashTransaction.Overview] { - get async { underlyingClearedTransactions } - } - - var underlyingSentTransactions: [ZcashTransaction.Sent]! = nil - var sentTransactions: [ZcashTransaction.Sent] { - get async { underlyingSentTransactions } - } - - var underlyingReceivedTransactions: [ZcashTransaction.Received]! = nil - var receivedTransactions: [ZcashTransaction.Received] { - get async { underlyingReceivedTransactions } - } - - var paginatedTransactionsOfKindClosure: ((TransactionKind) -> PaginatedTransactionRepository)! = nil - func paginatedTransactions(of kind: TransactionKind) -> PaginatedTransactionRepository { - return paginatedTransactionsOfKindClosure(kind) - } - - var getMemosForTransactionClosure: ((ZcashTransaction.Overview) async throws -> [Memo])! = nil - func getMemos(for transaction: ZcashTransaction.Overview) async throws -> [Memo] { - return try await getMemosForTransactionClosure(transaction) - } - - var getMemosForReceivedTransactionClosure: ((ZcashTransaction.Received) async throws -> [Memo])! = nil - func getMemos(for receivedTransaction: ZcashTransaction.Received) async throws -> [Memo] { - return try await getMemosForReceivedTransactionClosure(receivedTransaction) - } - - var getMemosForSentTransactionClosure: ((ZcashTransaction.Sent) async throws -> [Memo])! = nil - func getMemos(for sentTransaction: ZcashTransaction.Sent) async throws -> [Memo] { - return try await getMemosForSentTransactionClosure(sentTransaction) - } - - var getRecipientsForClearedTransactionClosure: ((ZcashTransaction.Overview) async -> [TransactionRecipient])! = nil - func getRecipients(for transaction: ZcashTransaction.Overview) async -> [TransactionRecipient] { - return await getRecipientsForClearedTransactionClosure(transaction) - } - - var getRecipientsForSentTransactionClosure: ((ZcashTransaction.Sent) async -> [TransactionRecipient])! = nil - func getRecipients(for transaction: ZcashTransaction.Sent) async -> [TransactionRecipient] { - return await getRecipientsForSentTransactionClosure(transaction) - } - - var allConfirmedTransactionsFromTransactionClosure: ((ZcashTransaction.Overview, Int) async throws -> [ZcashTransaction.Overview])! = nil - func allConfirmedTransactions(from transaction: ZcashTransaction.Overview, limit: Int) async throws -> [ZcashTransaction.Overview] { - return try await allConfirmedTransactionsFromTransactionClosure(transaction, limit) - } - - var latestHeightClosure: (() async throws -> BlockHeight)! = nil - func latestHeight() async throws -> BlockHeight { - return try await latestHeightClosure() - } - - var refreshUTXOsAddressFromHeightClosure: ((TransparentAddress, BlockHeight) async throws -> RefreshedUTXOs)! = nil - func refreshUTXOs(address: TransparentAddress, from height: BlockHeight) async throws -> RefreshedUTXOs { - return try await refreshUTXOsAddressFromHeightClosure(address, height) - } - - var getTransparentBalanceAccountIndexClosure: ((Int) async throws -> WalletBalance)! = nil - func getTransparentBalance(accountIndex: Int) async throws -> WalletBalance { - return try await getTransparentBalanceAccountIndexClosure(accountIndex) - } - - var getShieldedBalanceAccountIndexClosure: ((Int) async throws -> Zatoshi)! = nil - func getShieldedBalance(accountIndex: Int) async throws -> Zatoshi { - try await getShieldedBalanceAccountIndexClosure(accountIndex) - } - - var getShieldedVerifiedBalanceAccountIndexClosure: ((Int) async throws -> Zatoshi)! = nil - func getShieldedVerifiedBalance(accountIndex: Int) async throws -> Zatoshi { - try await getShieldedVerifiedBalanceAccountIndexClosure(accountIndex) - } - - var rewindPolicyClosure: ((RewindPolicy) -> AnyPublisher)! = nil - func rewind(_ policy: RewindPolicy) -> AnyPublisher { - return rewindPolicyClosure(policy) - } - - var wipeClosure: (() -> AnyPublisher)! = nil - func wipe() -> AnyPublisher { - return wipeClosure() - } -} diff --git a/Tests/TestUtils/TestCoordinator.swift b/Tests/TestUtils/TestCoordinator.swift index a37d1ed2..6c50ce06 100644 --- a/Tests/TestUtils/TestCoordinator.swift +++ b/Tests/TestUtils/TestCoordinator.swift @@ -52,56 +52,19 @@ class TestCoordinator { singleCallTimeoutInMillis: 10000, streamingCallTimeoutInMillis: 1000000 ) - - convenience init( + + init( alias: ZcashSynchronizerAlias = .default, walletBirthday: BlockHeight, network: ZcashNetwork, callPrepareInConstructor: Bool = true, endpoint: LightWalletEndpoint = TestCoordinator.defaultEndpoint, syncSessionIDGenerator: SyncSessionIDGenerator = UniqueSyncSessionIDGenerator() - ) async throws { - let derivationTool = DerivationTool(networkType: network.networkType) - - let spendingKey = try derivationTool.deriveUnifiedSpendingKey( - seed: Environment.seedBytes, - accountIndex: 0 - ) - - let ufvk = try derivationTool.deriveUnifiedFullViewingKey(from: spendingKey) - - try await self.init( - alias: alias, - spendingKey: spendingKey, - unifiedFullViewingKey: ufvk, - walletBirthday: walletBirthday, - network: network, - callPrepareInConstructor: callPrepareInConstructor, - endpoint: endpoint, - syncSessionIDGenerator: syncSessionIDGenerator - ) - } - - required init( - alias: ZcashSynchronizerAlias = .default, - spendingKey: UnifiedSpendingKey, - unifiedFullViewingKey: UnifiedFullViewingKey, - walletBirthday: BlockHeight, - network: ZcashNetwork, - callPrepareInConstructor: Bool = true, - endpoint: LightWalletEndpoint = TestCoordinator.defaultEndpoint, - syncSessionIDGenerator: SyncSessionIDGenerator ) async throws { await InternalSyncProgress(alias: alias, storage: UserDefaults.standard, logger: logger).rewind(to: 0) - self.spendingKey = spendingKey - self.viewingKey = unifiedFullViewingKey - self.birthday = walletBirthday - self.databases = TemporaryDbBuilder.build() - self.network = network - - let liveService = LightWalletServiceFactory(endpoint: endpoint).make() - self.service = DarksideWalletService(endpoint: endpoint, service: liveService) + let databases = TemporaryDbBuilder.build() + self.databases = databases let initializer = Initializer( cacheDbURL: nil, @@ -117,9 +80,21 @@ class TestCoordinator { logLevel: .debug ) - let synchronizer = SDKSynchronizer(initializer: initializer, sessionGenerator: syncSessionIDGenerator, sessionTicker: .live) - - self.synchronizer = synchronizer + let derivationTool = initializer.makeDerivationTool() + + self.spendingKey = try await derivationTool.deriveUnifiedSpendingKey( + seed: Environment.seedBytes, + accountIndex: 0 + ) + + self.viewingKey = try await derivationTool.deriveUnifiedFullViewingKey(from: spendingKey) + self.birthday = walletBirthday + self.network = network + + let liveService = LightWalletServiceFactory(endpoint: endpoint).make() + self.service = DarksideWalletService(endpoint: endpoint, service: liveService) + + self.synchronizer = SDKSynchronizer(initializer: initializer, sessionGenerator: syncSessionIDGenerator, sessionTicker: .live) subscribeToState(synchronizer: self.synchronizer) if callPrepareInConstructor { diff --git a/Tests/TestUtils/TestDbBuilder.swift b/Tests/TestUtils/TestDbBuilder.swift index 587fc8e7..62d201bb 100644 --- a/Tests/TestUtils/TestDbBuilder.swift +++ b/Tests/TestUtils/TestDbBuilder.swift @@ -58,16 +58,12 @@ enum TestDbBuilder { Bundle.module.url(forResource: "darkside_caches", withExtension: "db") } - static func prepopulatedDataDbProvider() async throws -> ConnectionProvider? { + static func prepopulatedDataDbProvider(rustBackend: ZcashRustBackendWelding) async throws -> ConnectionProvider? { guard let url = prePopulatedMainnetDataDbURL() else { return nil } let provider = SimpleConnectionProvider(path: url.absoluteString, readonly: true) - let initResult = try await ZcashRustBackend.initDataDb( - dbData: url, - seed: Environment.seedBytes, - networkType: .mainnet - ) + let initResult = try await rustBackend.initDataDb(seed: Environment.seedBytes) switch initResult { case .success: return provider @@ -76,19 +72,19 @@ enum TestDbBuilder { } } - static func transactionRepository() async throws -> TransactionRepository? { - guard let provider = try await prepopulatedDataDbProvider() else { return nil } + static func transactionRepository(rustBackend: ZcashRustBackendWelding) async throws -> TransactionRepository? { + guard let provider = try await prepopulatedDataDbProvider(rustBackend: rustBackend) else { return nil } return TransactionSQLDAO(dbProvider: provider) } - static func sentNotesRepository() async throws -> SentNotesRepository? { - guard let provider = try await prepopulatedDataDbProvider() else { return nil } + static func sentNotesRepository(rustBackend: ZcashRustBackendWelding) async throws -> SentNotesRepository? { + guard let provider = try await prepopulatedDataDbProvider(rustBackend: rustBackend) else { return nil } return SentNotesSQLDAO(dbProvider: provider) } - static func receivedNotesRepository() async throws -> ReceivedNoteRepository? { - guard let provider = try await prepopulatedDataDbProvider() else { return nil } + static func receivedNotesRepository(rustBackend: ZcashRustBackendWelding) async throws -> ReceivedNoteRepository? { + guard let provider = try await prepopulatedDataDbProvider(rustBackend: rustBackend) else { return nil } return ReceivedNotesSQLDAO(dbProvider: provider) } diff --git a/Tests/TestUtils/Tests+Utils.swift b/Tests/TestUtils/Tests+Utils.swift index 0312b13b..961f689f 100644 --- a/Tests/TestUtils/Tests+Utils.swift +++ b/Tests/TestUtils/Tests+Utils.swift @@ -9,10 +9,10 @@ import Combine import Foundation import GRPC -import ZcashLightClientKit import XCTest import NIO import NIOTransportServices +@testable import ZcashLightClientKit enum Environment { static let lightwalletdKey = "LIGHTWALLETD_ADDRESS" @@ -28,6 +28,11 @@ enum Environment { } static let testRecipientAddress = "zs17mg40levjezevuhdp5pqrd52zere7r7vrjgdwn5sj4xsqtm20euwahv9anxmwr3y3kmwuz8k55a" + + static var uniqueTestTempDirectory: URL { + URL(fileURLWithPath: NSString(string: NSTemporaryDirectory()) + .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) + } } public enum Constants { @@ -128,3 +133,21 @@ func parametersReady() -> Bool { return true } + +extension ZcashRustBackend { + static func makeForTests( + dbData: URL = try! __dataDbURL(), + fsBlockDbRoot: URL, + spendParamsPath: URL = SaplingParamsSourceURL.default.spendParamFileURL, + outputParamsPath: URL = SaplingParamsSourceURL.default.outputParamFileURL, + networkType: NetworkType + ) -> ZcashRustBackendWelding { + ZcashRustBackend( + dbData: dbData, + fsBlockDbRoot: fsBlockDbRoot, + spendParamsPath: spendParamsPath, + outputParamsPath: outputParamsPath, + networkType: networkType + ) + } +} diff --git a/Tests/TestUtils/AlternativeSynchronizerAPITestsData.swift b/Tests/TestUtils/TestsData.swift similarity index 64% rename from Tests/TestUtils/AlternativeSynchronizerAPITestsData.swift rename to Tests/TestUtils/TestsData.swift index 9838a881..74ba8f16 100644 --- a/Tests/TestUtils/AlternativeSynchronizerAPITestsData.swift +++ b/Tests/TestUtils/TestsData.swift @@ -1,5 +1,5 @@ // -// AlternativeSynchronizerAPITestsData.swift +// TestsData.swift // // // Created by Michal Fousek on 20.03.2023. @@ -8,13 +8,29 @@ import Foundation @testable import ZcashLightClientKit -class AlternativeSynchronizerAPITestsData { - let derivationTools = DerivationTool(networkType: .testnet) +class TestsData { + let networkType: NetworkType + + lazy var initialier = { + Initializer( + cacheDbURL: nil, + fsBlockDbRoot: URL(fileURLWithPath: "/"), + dataDbURL: URL(fileURLWithPath: "/"), + pendingDbURL: URL(fileURLWithPath: "/"), + endpoint: LightWalletEndpointBuilder.default, + network: ZcashNetworkBuilder.network(for: networkType), + spendParamsURL: URL(fileURLWithPath: "/"), + outputParamsURL: URL(fileURLWithPath: "/"), + saplingParamsSourceURL: .default + ) + }() + lazy var derivationTools: DerivationTool = { initialier.makeDerivationTool() }() let saplingAddress = SaplingAddress(validatedEncoding: "ztestsapling1ctuamfer5xjnnrdr3xdazenljx0mu0gutcf9u9e74tr2d3jwjnt0qllzxaplu54hgc2tyjdc2p6") let unifiedAddress = UnifiedAddress( validatedEncoding: """ u1l9f0l4348negsncgr9pxd9d3qaxagmqv3lnexcplmufpq7muffvfaue6ksevfvd7wrz7xrvn95rc5zjtn7ugkmgh5rnxswmcj30y0pw52pn0zjvy38rn2esfgve64rj5pcmazxgpyuj - """ + """, + networkType: .testnet ) let transparentAddress = TransparentAddress(validatedEncoding: "t1dRJRY7GmyeykJnMH38mdQoaZtFhn1QmGz") lazy var pendingTransactionEntity = { @@ -73,9 +89,19 @@ class AlternativeSynchronizerAPITestsData { }() var seed: [UInt8] = Environment.seedBytes - lazy var spendingKey: UnifiedSpendingKey = { try! derivationTools.deriveUnifiedSpendingKey(seed: seed, accountIndex: 0) }() - lazy var viewingKey: UnifiedFullViewingKey = { try! derivationTools.deriveUnifiedFullViewingKey(from: spendingKey) }() + var spendingKey: UnifiedSpendingKey { + get async { + try! await derivationTools.deriveUnifiedSpendingKey(seed: seed, accountIndex: 0) + } + } + var viewingKey: UnifiedFullViewingKey { + get async { + try! await derivationTools.deriveUnifiedFullViewingKey(from: spendingKey) + } + } var birthday: BlockHeight = 123000 - init() { } + init(networkType: NetworkType) { + self.networkType = networkType + } }