From 2bbc5800bd78de37827aec935879329bb25dd914 Mon Sep 17 00:00:00 2001 From: Michal Fousek Date: Fri, 31 Mar 2023 19:10:35 +0200 Subject: [PATCH] [#888] Make actor from ZcashRustBackendWelding Closes #888. - `ZcashRustBackend` is actor now. So majority of methods in this actor are now async. - Some methods stayed `static` in `ZcashRustBackend`. It would be hard to pass instance of the `ZcashRustBackend` to the places where these methods are used in static manner. And it would change lot of APIs. But it isn't problem from technical perspective because these methods would be `nonisolated` otherwise. - Methods `lastError()` and `getLastError()` in `ZcashRustBackend` are now private. This makes sure that ther won't be aby race condition between other methods and these two error methods. - All the methods for which was `lastError()` used in code now throw error. So `lastError()` is no longer needed outside of the `ZcashRustBackend`. - There are in the public API related to `DerivationTool`. - `DerivationTool` now requires instance of the `ZcashRustBackend`. And `ZcashRustBackend` isn't public type. So `DerivationTool` doesn't have any public constructor now. It can be created only via `Initializer.makeDerivationTool()` instance method. - `deriveUnifiedSpendingKey()` and `deriveUnifiedFullViewingKey()` in `DerivationTool` are now async. It is because these are using `ZcashRustBackend` inside. `DerivationTool` offers alternative (closure and combine) APIs. But downside is that there is no sync API to dervie spending key or viewing key. - Some methods of the `DerivationTool` are now static. These methods don't use anything that requires instance of the `DerivationTool` inside. [#888] Use Sourcery to generate mocks - I wrote mock for `Synchronizer` manually. And it's tedious and long and boring work. - Now `ZcashRustBackendWelding` is changed a lot so it means `MockRustBackend` must be changed a lot. So I decided to introduce `sourcery` to generate mocks from protocols so we don't have to do it manually ever. - To generate mocks go to `ZcashLightClientKit/Tests/TestUtils/Sourcery` directory and run `generateMocks.sh` script. - Your protocol must be mentioned in `AutoMockable.swift` file. Generated mocks are in `AutoMockable.generated.swift` file. [#888] Fix Offline tests - Offline tests target now runs and tests are green. - There is log of changes in tests. But logic is not changed. - Updated `AutoMockable.stencil` so sourcery is able to generate mock as actor when protocol is marked with: `// sourcery: mockActor`. - Last few updates in `ZcashRustBackendWelding`. In previous PR `rewindCacheToHeight` methods was overlooked and it didn't throw error. - Removed `MockRustBackend` and using generated `ZCashRustBackendWeldingMock` instead. - Using generated `SynchronizerMock`. [#888] Fix NetworkTests - Changed a bit how rust backend mock is used in the tests. Introduced `RustBackendMockHelper`. There are some state variables that must be preserved within one instance of the mock. This helper does exactly this. It keeps this state variables in the memory and helping mock to work as expected. [#888] Fix Darkside tests Create ZcashKeyDeriving internal protocol Use New DerivationTool that does not require RustBackend Remove duplicated methods that had been copied over [#888] Fix potentially broken tests I broke the tests because I moved `testTempDirectory` from each `TestCase` to the `Environment`. By this I caused that each tests uses exactly same URL. Which is directly against purpose of `testTempDirectory`. So now each test calls this one and store it to local variable. So each test has unique URL. [#888] Add ability to mock nonisolated methods to AutoMockable.stencil [#888] Add changelog and fix the documentation in ZcashRustBackendWelding [#888] Rename derivation rust backend protocol and remove static methods - Renamed `ZcashKeyDeriving` to `ZcashKeyDerivationBackendWelding`. So the naming scheme is same as for `ZcashRustBackendWelding`. - `ZcashKeyDerivationBackend` is now struct instead of enum. - Methods in `ZcashKeyDerivationBackendWelding` (except one) are no longer static. Because of this the respective methods in `DerivationTool` aren't also static anymore. --- .gitignore | 1 - .swiftlint.yml | 1 + .swiftlint_tests.yml | 1 + CHANGELOG.md | 8 + .../xcshareddata/xcschemes/All.xcscheme | 14 + .../ZcashLightClientSample/AppDelegate.swift | 10 +- .../Get UTXOs/GetUTXOsViewController.swift | 30 +- .../Send/SendViewController.swift | 26 +- .../SyncBlocksListViewController.swift | 46 +- Package.swift | 4 +- .../Block/CompactBlockProcessor.swift | 119 +- .../Block/Enhance/BlockEnhancer.swift | 24 +- .../FetchUnspentTxOutputs/UTXOFetcher.swift | 24 +- .../FSCompactBlockRepository.swift | 25 +- .../SaplingParametersHandler.swift | 16 +- .../Block/Scan/BlockScanner.swift | 17 +- .../Block/Validate/BlockValidator.swift | 52 +- .../CombineSynchronizer.swift | 7 - .../DAO/PendingTransactionDao.swift | 2 +- Sources/ZcashLightClientKit/Initializer.swift | 36 +- .../Model/WalletTypes.swift | 11 +- .../Rust/ZcashKeyDerivationBackend.swift | 225 +++ .../ZcashKeyDerivationBackendWelding.swift | 82 ++ .../Rust/ZcashRustBackend.swift | 504 ++----- .../Rust/ZcashRustBackendWelding.swift | 370 +---- .../ZcashLightClientKit/Synchronizer.swift | 10 + .../Synchronizer/ClosureSDKSynchronizer.swift | 86 +- .../Synchronizer/CombineSDKSynchronizer.swift | 98 +- .../Synchronizer/SDKSynchronizer.swift | 13 +- .../Tool/DerivationTool.swift | 109 +- .../WalletTransactionEncoder.swift | 38 +- .../Utils/AsyncToClosureGateway.swift | 46 + .../Utils/AsyncToCombineGateway.swift | 59 + .../Utils/SpecificCombineTypes.swift | 16 + .../Utils/SyncSessionIDGenerator.swift | 1 - .../DarksideTests/BlockDownloaderTests.swift | 17 +- Tests/DarksideTests/RewindRescanTests.swift | 17 +- .../SynchronizerDarksideTests.swift | 5 +- .../TransactionEnhancementTests.swift | 48 +- Tests/NetworkTests/BlockScanTests.swift | 72 +- Tests/NetworkTests/BlockStreamingTest.swift | 25 +- .../CompactBlockProcessorTests.swift | 66 +- .../NetworkTests/CompactBlockReorgTests.swift | 82 +- Tests/NetworkTests/DownloadTests.swift | 19 +- .../BlockBatchValidationTests.swift | 70 +- .../ClosureSynchronizerOfflineTests.swift | 60 +- .../CombineSynchronizerOfflineTests.swift | 60 +- .../CompactBlockProcessorOfflineTests.swift | 16 +- .../CompactBlockRepositoryTests.swift | 22 +- .../DerivationToolMainnetTests.swift | 67 +- .../DerivationToolTestnetTests.swift | 57 +- Tests/OfflineTests/FsBlockStorageTests.swift | 75 +- Tests/OfflineTests/NotesRepositoryTests.swift | 5 +- Tests/OfflineTests/NullBytesTests.swift | 28 +- Tests/OfflineTests/RecipientTests.swift | 2 +- .../SynchronizerOfflineTests.swift | 9 +- .../TransactionRepositoryTests.swift | 3 +- .../OfflineTests/UnifiedTypecodesTests.swift | 13 +- Tests/OfflineTests/WalletTests.swift | 22 +- .../OfflineTests/ZcashRustBackendTests.swift | 70 +- .../PerformanceTests/SynchronizerTests.swift | 14 +- .../CompactBlockProcessorEventHandler.swift | 1 + Tests/TestUtils/Sourcery/AutoMockable.stencil | 165 +++ Tests/TestUtils/Sourcery/AutoMockable.swift | 18 + .../AutoMockable.generated.swift | 1302 +++++++++++++++++ Tests/TestUtils/Sourcery/generateMocks.sh | 22 + Tests/TestUtils/Stubs.swift | 414 ++---- Tests/TestUtils/SynchronizerMock.swift | 174 --- Tests/TestUtils/TestCoordinator.swift | 63 +- Tests/TestUtils/TestDbBuilder.swift | 20 +- Tests/TestUtils/Tests+Utils.swift | 25 +- ...izerAPITestsData.swift => TestsData.swift} | 40 +- 72 files changed, 3122 insertions(+), 2197 deletions(-) create mode 100644 Sources/ZcashLightClientKit/Rust/ZcashKeyDerivationBackend.swift create mode 100644 Sources/ZcashLightClientKit/Rust/ZcashKeyDerivationBackendWelding.swift create mode 100644 Sources/ZcashLightClientKit/Utils/AsyncToClosureGateway.swift create mode 100644 Sources/ZcashLightClientKit/Utils/AsyncToCombineGateway.swift create mode 100644 Sources/ZcashLightClientKit/Utils/SpecificCombineTypes.swift create mode 100644 Tests/TestUtils/Sourcery/AutoMockable.stencil create mode 100644 Tests/TestUtils/Sourcery/AutoMockable.swift create mode 100644 Tests/TestUtils/Sourcery/GeneratedMocks/AutoMockable.generated.swift create mode 100755 Tests/TestUtils/Sourcery/generateMocks.sh delete mode 100644 Tests/TestUtils/SynchronizerMock.swift rename Tests/TestUtils/{AlternativeSynchronizerAPITestsData.swift => TestsData.swift} (64%) diff --git a/.gitignore b/.gitignore index 3cf17d73..6ce73d07 100644 --- a/.gitignore +++ b/.gitignore @@ -76,7 +76,6 @@ Pods # do not commit generated libraries to this repo lib *.a -*.generated.swift env-vars.sh .vscode/ diff --git a/.swiftlint.yml b/.swiftlint.yml index e77b991a..f3f0a22a 100644 --- a/.swiftlint.yml +++ b/.swiftlint.yml @@ -15,6 +15,7 @@ excluded: - ZcashLightClientKitTests/Constants.generated.swift - build/ - docs/ + - Tests/TestUtils/Sourcery/GeneratedMocks/AutoMockable.generated.swift disabled_rules: - notification_center_detachment diff --git a/.swiftlint_tests.yml b/.swiftlint_tests.yml index 83611d47..aaa1c68f 100644 --- a/.swiftlint_tests.yml +++ b/.swiftlint_tests.yml @@ -11,6 +11,7 @@ excluded: - ZcashLightClientKitTests/Constants.generated.swift - build/ - docs/ + - Tests/TestUtils/Sourcery/GeneratedMocks/AutoMockable.generated.swift disabled_rules: - notification_center_detachment diff --git a/CHANGELOG.md b/CHANGELOG.md index 49fd4dd0..8500ad73 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,12 @@ # unreleased +### [#888] Updates to layer between Swift and Rust + +This is mostly internal change. But it also touches the public API. + +`KeyDeriving` protocol is changed. And therefore `DerivationTool` is changed. `deriveUnifiedSpendingKey(seed:accountIndex:)` and +`deriveUnifiedFullViewingKey(from:)` methods are now async. `DerivationTool` offers alternatives for these methods. Alternatives are using either +closures or Combine. + ### [#469] ZcashRustBackendWelding to Async This is mostly internal change. But it also touches the public API. diff --git a/Example/ZcashLightClientSample/ZcashLightClientSample.xcodeproj/xcshareddata/xcschemes/All.xcscheme b/Example/ZcashLightClientSample/ZcashLightClientSample.xcodeproj/xcshareddata/xcschemes/All.xcscheme index 763fac8f..a8723e0a 100644 --- a/Example/ZcashLightClientSample/ZcashLightClientSample.xcodeproj/xcshareddata/xcschemes/All.xcscheme +++ b/Example/ZcashLightClientSample/ZcashLightClientSample.xcodeproj/xcshareddata/xcschemes/All.xcscheme @@ -104,6 +104,20 @@ ReferencedContainer = "container:../.."> + + + + 0 else { - let error = rustBackend.lastError() ?? RustWeldingError.genericError( - message: "unknown error getting nearest rewind height for height: \(height)" - ) + let nearestHeight: Int32 + do { + nearestHeight = try await rustBackend.getNearestRewindHeight(height: height) + } catch { await fail(error) return await context.completion(.failure(error)) } // FIXME: [#719] this should be done on the rust layer, https://github.com/zcash/ZcashLightClientKit/issues/719 let rewindHeight = max(Int32(nearestHeight - 1), Int32(config.walletBirthday)) - guard await rustBackend.rewindToHeight(dbData: config.dataDb, height: rewindHeight, networkType: self.config.network.networkType) else { - let error = rustBackend.lastError() ?? RustWeldingError.genericError(message: "unknown error rewinding to height \(height)") + + do { + try await rustBackend.rewindToHeight(height: rewindHeight) + } catch { await fail(error) return await context.completion(.failure(error)) } @@ -715,7 +694,7 @@ actor CompactBlockProcessor { func validateServer() async { do { let info = try await self.service.getInfo() - try Self.validateServerInfo( + try await Self.validateServerInfo( info, saplingActivation: self.config.saplingActivation, localNetwork: self.config.network, @@ -1055,14 +1034,10 @@ actor CompactBlockProcessor { self.consecutiveChainValidationErrors += 1 - let rewindResult = await rustBackend.rewindToHeight( - dbData: config.dataDb, - height: Int32(rewindHeight), - networkType: self.config.network.networkType - ) - - guard rewindResult else { - await fail(rustBackend.lastError() ?? RustWeldingError.genericError(message: "unknown error rewinding to height \(height)")) + do { + try await rustBackend.rewindToHeight(height: Int32(rewindHeight)) + } catch { + await fail(error) return } @@ -1205,11 +1180,7 @@ extension CompactBlockProcessor.State: Equatable { extension CompactBlockProcessor { func getUnifiedAddress(accountIndex: Int) async throws -> UnifiedAddress { - try await rustBackend.getCurrentAddress( - dbData: config.dataDb, - account: Int32(accountIndex), - networkType: config.network.networkType - ) + try await rustBackend.getCurrentAddress(account: Int32(accountIndex)) } func getSaplingAddress(accountIndex: Int) async throws -> SaplingAddress { @@ -1227,18 +1198,10 @@ extension CompactBlockProcessor { return WalletBalance( verified: Zatoshi( - try await rustBackend.getVerifiedTransparentBalance( - dbData: config.dataDb, - account: Int32(accountIndex), - networkType: config.network.networkType - ) + try await rustBackend.getVerifiedTransparentBalance(account: Int32(accountIndex)) ), total: Zatoshi( - try await rustBackend.getTransparentBalance( - dbData: config.dataDb, - account: Int32(accountIndex), - networkType: config.network.networkType - ) + try await rustBackend.getTransparentBalance(account: Int32(accountIndex)) ) ) } @@ -1269,19 +1232,15 @@ extension CompactBlockProcessor { var skipped: [UnspentTransactionOutputEntity] = [] for utxo in utxos { do { - if try await rustBackend.putUnspentTransparentOutput( - dbData: dataDb, + try await rustBackend.putUnspentTransparentOutput( txid: utxo.txid.bytes, index: utxo.index, script: utxo.script.bytes, value: Int64(utxo.valueZat), - height: utxo.height, - networkType: self.config.network.networkType - ) { - refreshed.append(utxo) - } else { - skipped.append(utxo) - } + height: utxo.height + ) + + refreshed.append(utxo) } catch { logger.info("failed to put utxo - error: \(error)") skipped.append(utxo) @@ -1389,7 +1348,7 @@ extension CompactBlockProcessor { downloaderService: BlockDownloaderService, transactionRepository: TransactionRepository, config: Configuration, - rustBackend: ZcashRustBackendWelding.Type, + rustBackend: ZcashRustBackendWelding, internalSyncProgress: InternalSyncProgress ) async throws -> CompactBlockProcessor.NextState { // It should be ok to not create new Task here because this method is already async. But for some reason something not good happens @@ -1397,7 +1356,7 @@ extension CompactBlockProcessor { let task = Task(priority: .userInitiated) { let info = try await service.getInfo() - try CompactBlockProcessor.validateServerInfo( + try await CompactBlockProcessor.validateServerInfo( info, saplingActivation: config.saplingActivation, localNetwork: config.network, diff --git a/Sources/ZcashLightClientKit/Block/Enhance/BlockEnhancer.swift b/Sources/ZcashLightClientKit/Block/Enhance/BlockEnhancer.swift index ad209a75..c5081534 100644 --- a/Sources/ZcashLightClientKit/Block/Enhance/BlockEnhancer.swift +++ b/Sources/ZcashLightClientKit/Block/Enhance/BlockEnhancer.swift @@ -14,20 +14,14 @@ enum BlockEnhancerError: Error { case txIdNotFound(txId: Data) } -struct BlockEnhancerConfig { - let dataDb: URL - let networkType: NetworkType -} - protocol BlockEnhancer { func enhance(at range: CompactBlockRange, didEnhance: (EnhancementProgress) async -> Void) async throws -> [ZcashTransaction.Overview] } struct BlockEnhancerImpl { let blockDownloaderService: BlockDownloaderService - let config: BlockEnhancerConfig let internalSyncProgress: InternalSyncProgress - let rustBackend: ZcashRustBackendWelding.Type + let rustBackend: ZcashRustBackendWelding let transactionRepository: TransactionRepository let metrics: SDKMetrics let logger: Logger @@ -41,17 +35,13 @@ struct BlockEnhancerImpl { let block = String(describing: transaction.minedHeight) logger.debug("Decrypting and storing transaction id: \(transactionID) block: \(block)") - let decryptionResult = await rustBackend.decryptAndStoreTransaction( - dbData: config.dataDb, - txBytes: fetchedTransaction.raw.bytes, - minedHeight: Int32(fetchedTransaction.minedHeight), - networkType: config.networkType - ) - - guard decryptionResult else { - throw BlockEnhancerError.decryptError( - error: rustBackend.lastError() ?? .genericError(message: "`decryptAndStoreTransaction` failed. No message available") + do { + try await rustBackend.decryptAndStoreTransaction( + txBytes: fetchedTransaction.raw.bytes, + minedHeight: Int32(fetchedTransaction.minedHeight) ) + } catch { + throw BlockEnhancerError.decryptError(error: error) } let confirmedTx: ZcashTransaction.Overview diff --git a/Sources/ZcashLightClientKit/Block/FetchUnspentTxOutputs/UTXOFetcher.swift b/Sources/ZcashLightClientKit/Block/FetchUnspentTxOutputs/UTXOFetcher.swift index 80cc3096..dd3f9b1d 100644 --- a/Sources/ZcashLightClientKit/Block/FetchUnspentTxOutputs/UTXOFetcher.swift +++ b/Sources/ZcashLightClientKit/Block/FetchUnspentTxOutputs/UTXOFetcher.swift @@ -13,8 +13,6 @@ enum UTXOFetcherError: Error { } struct UTXOFetcherConfig { - let dataDb: URL - let networkType: NetworkType let walletBirthdayProvider: () async -> BlockHeight } @@ -27,7 +25,7 @@ struct UTXOFetcherImpl { let blockDownloaderService: BlockDownloaderService let config: UTXOFetcherConfig let internalSyncProgress: InternalSyncProgress - let rustBackend: ZcashRustBackendWelding.Type + let rustBackend: ZcashRustBackendWelding let metrics: SDKMetrics let logger: Logger } @@ -41,11 +39,7 @@ extension UTXOFetcherImpl: UTXOFetcher { var tAddresses: [TransparentAddress] = [] for account in accounts { - tAddresses += try await rustBackend.listTransparentReceivers( - dbData: config.dataDb, - account: Int32(account), - networkType: config.networkType - ) + tAddresses += try await rustBackend.listTransparentReceivers(account: Int32(account)) } var utxos: [UnspentTransactionOutputEntity] = [] @@ -64,19 +58,15 @@ extension UTXOFetcherImpl: UTXOFetcher { let startTime = Date() for utxo in utxos { do { - if try await rustBackend.putUnspentTransparentOutput( - dbData: config.dataDb, + try await rustBackend.putUnspentTransparentOutput( txid: utxo.txid.bytes, index: utxo.index, script: utxo.script.bytes, value: Int64(utxo.valueZat), - height: utxo.height, - networkType: config.networkType - ) { - refreshed.append(utxo) - } else { - skipped.append(utxo) - } + height: utxo.height + ) + + refreshed.append(utxo) await internalSyncProgress.set(utxo.height, .latestUTXOFetchedHeight) } catch { diff --git a/Sources/ZcashLightClientKit/Block/FilesystemStorage/FSCompactBlockRepository.swift b/Sources/ZcashLightClientKit/Block/FilesystemStorage/FSCompactBlockRepository.swift index 3192242c..fe9432fe 100644 --- a/Sources/ZcashLightClientKit/Block/FilesystemStorage/FSCompactBlockRepository.swift +++ b/Sources/ZcashLightClientKit/Block/FilesystemStorage/FSCompactBlockRepository.swift @@ -52,7 +52,10 @@ extension FSCompactBlockRepository: CompactBlockRepository { try fileManager.createDirectory(at: blocksDirectory, withIntermediateDirectories: true) } - guard try await self.metadataStore.initFsBlockDbRoot(self.fsBlockDbRoot) else { + do { + try await self.metadataStore.initFsBlockDbRoot() + } catch { + logger.error("Blocks metadata store init failed with error: \(error)") throw CompactBlockRepositoryError.failedToInitializeCache } } @@ -210,12 +213,12 @@ extension FSBlockFileWriter { struct FSMetadataStore { var saveBlocksMeta: ([ZcashCompactBlock]) async throws -> Void var rewindToHeight: (BlockHeight) async throws -> Void - var initFsBlockDbRoot: (URL) async throws -> Bool + var initFsBlockDbRoot: () async throws -> Void var latestHeight: () async -> BlockHeight } extension FSMetadataStore { - static func live(fsBlockDbRoot: URL, rustBackend: ZcashRustBackendWelding.Type, logger: Logger) -> FSMetadataStore { + static func live(fsBlockDbRoot: URL, rustBackend: ZcashRustBackendWelding, logger: Logger) -> FSMetadataStore { FSMetadataStore { blocks in try await FSMetadataStore.saveBlocksMeta( blocks, @@ -224,13 +227,15 @@ extension FSMetadataStore { logger: logger ) } rewindToHeight: { height in - guard await rustBackend.rewindCacheToHeight(fsBlockDbRoot: fsBlockDbRoot, height: Int32(height)) else { + do { + try await rustBackend.rewindCacheToHeight(height: Int32(height)) + } catch { throw CompactBlockRepositoryError.failedToRewind(height) } - } initFsBlockDbRoot: { dbRootURL in - try await rustBackend.initBlockMetadataDb(fsBlockDbRoot: dbRootURL) + } initFsBlockDbRoot: { + try await rustBackend.initBlockMetadataDb() } latestHeight: { - await rustBackend.latestCachedBlockHeight(fsBlockDbRoot: fsBlockDbRoot) + await rustBackend.latestCachedBlockHeight() } } } @@ -244,15 +249,13 @@ extension FSMetadataStore { static func saveBlocksMeta( _ blocks: [ZcashCompactBlock], fsBlockDbRoot: URL, - rustBackend: ZcashRustBackendWelding.Type, + rustBackend: ZcashRustBackendWelding, logger: Logger ) async throws { guard !blocks.isEmpty else { return } do { - guard try await rustBackend.writeBlocksMetadata(fsBlockDbRoot: fsBlockDbRoot, blocks: blocks) else { - throw CompactBlockRepositoryError.failedToWriteMetadata - } + try await rustBackend.writeBlocksMetadata(blocks: blocks) } catch { logger.error("Failed to write metadata with error: \(error)") throw CompactBlockRepositoryError.failedToWriteMetadata diff --git a/Sources/ZcashLightClientKit/Block/SaplingParameters/SaplingParametersHandler.swift b/Sources/ZcashLightClientKit/Block/SaplingParameters/SaplingParametersHandler.swift index 08bb8d85..09746b6d 100644 --- a/Sources/ZcashLightClientKit/Block/SaplingParameters/SaplingParametersHandler.swift +++ b/Sources/ZcashLightClientKit/Block/SaplingParameters/SaplingParametersHandler.swift @@ -8,8 +8,6 @@ import Foundation struct SaplingParametersHandlerConfig { - let dataDb: URL - let networkType: NetworkType let outputParamsURL: URL let spendParamsURL: URL let saplingParamsSourceURL: SaplingParamsSourceURL @@ -21,7 +19,7 @@ protocol SaplingParametersHandler { struct SaplingParametersHandlerImpl { let config: SaplingParametersHandlerConfig - let rustBackend: ZcashRustBackendWelding.Type + let rustBackend: ZcashRustBackendWelding let logger: Logger } @@ -30,16 +28,8 @@ extension SaplingParametersHandlerImpl: SaplingParametersHandler { try Task.checkCancellation() do { - let totalShieldedBalance = try await rustBackend.getBalance( - dbData: config.dataDb, - account: Int32(0), - networkType: config.networkType - ) - let totalTransparentBalance = try await rustBackend.getTransparentBalance( - dbData: config.dataDb, - account: Int32(0), - networkType: config.networkType - ) + let totalShieldedBalance = try await rustBackend.getBalance(account: Int32(0)) + let totalTransparentBalance = try await rustBackend.getTransparentBalance(account: Int32(0)) // Download Sapling parameters only if sapling funds are detected. guard totalShieldedBalance > 0 || totalTransparentBalance > 0 else { return } diff --git a/Sources/ZcashLightClientKit/Block/Scan/BlockScanner.swift b/Sources/ZcashLightClientKit/Block/Scan/BlockScanner.swift index f4f9ff34..a49ca4c2 100644 --- a/Sources/ZcashLightClientKit/Block/Scan/BlockScanner.swift +++ b/Sources/ZcashLightClientKit/Block/Scan/BlockScanner.swift @@ -8,8 +8,6 @@ import Foundation struct BlockScannerConfig { - let fsBlockCacheRoot: URL - let dataDB: URL let networkType: NetworkType let scanningBatchSize: Int } @@ -20,7 +18,7 @@ protocol BlockScanner { struct BlockScannerImpl { let config: BlockScannerConfig - let rustBackend: ZcashRustBackendWelding.Type + let rustBackend: ZcashRustBackendWelding let transactionRepository: TransactionRepository let metrics: SDKMetrics let logger: Logger @@ -42,19 +40,16 @@ extension BlockScannerImpl: BlockScanner { let previousScannedHeight = lastScannedHeight // TODO: [#576] remove this arbitrary batch size https://github.com/zcash/ZcashLightClientKit/issues/576 - let batchSize = scanBatchSize(startScanHeight: previousScannedHeight + 1, network: self.config.networkType) + let batchSize = scanBatchSize(startScanHeight: previousScannedHeight + 1, network: config.networkType) let scanStartTime = Date() - guard await self.rustBackend.scanBlocks( - fsBlockDbRoot: config.fsBlockCacheRoot, - dbData: config.dataDB, - limit: batchSize, - networkType: config.networkType - ) else { - let error: Error = rustBackend.lastError() ?? CompactBlockProcessorError.unknown + do { + try await self.rustBackend.scanBlocks(limit: batchSize) + } catch { logger.debug("block scanning failed with error: \(String(describing: error))") throw error } + let scanFinishTime = Date() lastScannedHeight = try await transactionRepository.lastScannedHeight() diff --git a/Sources/ZcashLightClientKit/Block/Validate/BlockValidator.swift b/Sources/ZcashLightClientKit/Block/Validate/BlockValidator.swift index 9a78ff15..1936f621 100644 --- a/Sources/ZcashLightClientKit/Block/Validate/BlockValidator.swift +++ b/Sources/ZcashLightClientKit/Block/Validate/BlockValidator.swift @@ -14,20 +14,13 @@ enum BlockValidatorError: Error { case failedWithUnknownError } -struct BlockValidatorConfig { - let fsBlockCacheRoot: URL - let dataDB: URL - let networkType: NetworkType -} - protocol BlockValidator { /// Validate all the downloaded blocks that haven't been yet validated. func validate() async throws } struct BlockValidatorImpl { - let config: BlockValidatorConfig - let rustBackend: ZcashRustBackendWelding.Type + let rustBackend: ZcashRustBackendWelding let metrics: SDKMetrics let logger: Logger } @@ -37,14 +30,24 @@ extension BlockValidatorImpl: BlockValidator { try Task.checkCancellation() let startTime = Date() - let result = await rustBackend.validateCombinedChain( - fsBlockDbRoot: config.fsBlockCacheRoot, - dbData: config.dataDB, - networkType: config.networkType, - limit: 0 - ) - let finishTime = Date() + do { + try await rustBackend.validateCombinedChain(limit: 0) + pushProgressReport(startTime: startTime, finishTime: Date()) + logger.debug("validateChainFinished") + } catch { + pushProgressReport(startTime: startTime, finishTime: Date()) + switch error { + case let RustWeldingError.invalidChain(upperBound): + throw BlockValidatorError.validationFailed(height: BlockHeight(upperBound)) + + default: + throw BlockValidatorError.failedWithError(error) + } + } + } + + private func pushProgressReport(startTime: Date, finishTime: Date) { metrics.pushProgressReport( progress: BlockProgress(startHeight: 0, targetHeight: 0, progressHeight: 0), start: startTime, @@ -52,24 +55,5 @@ extension BlockValidatorImpl: BlockValidator { batchSize: 0, operation: .validateBlocks ) - - switch result { - case 0: - let rustError = rustBackend.lastError() - logger.debug("Block validation failed with error: \(String(describing: rustError))") - if let rustError { - throw BlockValidatorError.failedWithError(rustError) - } else { - throw BlockValidatorError.failedWithUnknownError - } - - case ZcashRustBackendWeldingConstants.validChain: - logger.debug("validateChainFinished") - return - - default: - logger.debug("Block validation failed at height: \(result)") - throw BlockValidatorError.validationFailed(height: BlockHeight(result)) - } } } diff --git a/Sources/ZcashLightClientKit/CombineSynchronizer.swift b/Sources/ZcashLightClientKit/CombineSynchronizer.swift index 5d35ceaa..c25b9399 100644 --- a/Sources/ZcashLightClientKit/CombineSynchronizer.swift +++ b/Sources/ZcashLightClientKit/CombineSynchronizer.swift @@ -8,13 +8,6 @@ import Combine import Foundation -/* These aliases are here to just make the API easier to read. */ - -// Publisher which emitts completed or error. No value is emitted. -public typealias CompletablePublisher = AnyPublisher -// Publisher that either emits one value and then finishes or it emits error. -public typealias SinglePublisher = AnyPublisher - /// This defines a Combine-based API for the SDK. It's expected that the implementation of this protocol is only a very thin layer that translates /// async API defined in `Synchronizer` to Combine-based API. And it doesn't do anything else. It's here so each client can choose the API that suits /// its case the best. diff --git a/Sources/ZcashLightClientKit/DAO/PendingTransactionDao.swift b/Sources/ZcashLightClientKit/DAO/PendingTransactionDao.swift index 17520ca1..a51ce521 100644 --- a/Sources/ZcashLightClientKit/DAO/PendingTransactionDao.swift +++ b/Sources/ZcashLightClientKit/DAO/PendingTransactionDao.swift @@ -28,7 +28,7 @@ struct PendingTransaction: PendingTransactionEntity, Decodable, Encodable { case rawTransactionId = "txid" case fee } - + var recipient: PendingTransactionRecipient var accountIndex: Int var minedHeight: BlockHeight diff --git a/Sources/ZcashLightClientKit/Initializer.swift b/Sources/ZcashLightClientKit/Initializer.swift index 7ec08cdc..5c07ec66 100644 --- a/Sources/ZcashLightClientKit/Initializer.swift +++ b/Sources/ZcashLightClientKit/Initializer.swift @@ -115,7 +115,6 @@ public class Initializer { // This is used to uniquely identify instance of the SDKSynchronizer. It's used when checking if the Alias is already used or not. let id = UUID() - let rustBackend: ZcashRustBackendWelding.Type let alias: ZcashSynchronizerAlias let endpoint: LightWalletEndpoint let fsBlockDbRoot: URL @@ -131,6 +130,7 @@ public class Initializer { let blockDownloaderService: BlockDownloaderService let network: ZcashNetwork let logger: Logger + let rustBackend: ZcashRustBackendWelding /// The effective birthday of the wallet based on the height provided when initializing and the checkpoints available on this SDK. /// @@ -180,8 +180,16 @@ public class Initializer { let (updatedURLs, parsingError) = Self.tryToUpdateURLs(with: alias, urls: urls) let logger = OSLogger(logLevel: logLevel, alias: alias) + let rustBackend = ZcashRustBackend( + dbData: updatedURLs.dataDbURL, + fsBlockDbRoot: updatedURLs.fsBlockDbRoot, + spendParamsPath: updatedURLs.spendParamsURL, + outputParamsPath: updatedURLs.outputParamsURL, + networkType: network.networkType + ) + self.init( - rustBackend: ZcashRustBackend.self, + rustBackend: rustBackend, network: network, cacheDbURL: cacheDbURL, urls: updatedURLs, @@ -198,7 +206,7 @@ public class Initializer { fsBlockDbRoot: updatedURLs.fsBlockDbRoot, metadataStore: .live( fsBlockDbRoot: updatedURLs.fsBlockDbRoot, - rustBackend: ZcashRustBackend.self, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -216,7 +224,7 @@ public class Initializer { /// /// !!! It's expected that URLs put here are already update with the Alias. init( - rustBackend: ZcashRustBackendWelding.Type, + rustBackend: ZcashRustBackendWelding, network: ZcashNetwork, cacheDbURL: URL?, urls: URLs, @@ -341,7 +349,7 @@ public class Initializer { } do { - if case .seedRequired = try await rustBackend.initDataDb(dbData: dataDbURL, seed: seed, networkType: network.networkType) { + if case .seedRequired = try await rustBackend.initDataDb(seed: seed) { return .seedRequired } } catch { @@ -351,12 +359,10 @@ public class Initializer { let checkpoint = Checkpoint.birthday(with: walletBirthday, network: network) do { try await rustBackend.initBlocksTable( - dbData: dataDbURL, height: Int32(checkpoint.height), hash: checkpoint.hash, time: checkpoint.time, - saplingTree: checkpoint.saplingTree, - networkType: network.networkType + saplingTree: checkpoint.saplingTree ) } catch RustWeldingError.dataDbNotEmpty { // this is fine @@ -367,11 +373,7 @@ public class Initializer { self.walletBirthday = checkpoint.height do { - try await rustBackend.initAccountsTable( - dbData: dataDbURL, - ufvks: viewingKeys, - networkType: network.networkType - ) + try await rustBackend.initAccountsTable(ufvks: viewingKeys) } catch RustWeldingError.dataDbNotEmpty { // this is fine } catch RustWeldingError.malformedStringInput { @@ -395,14 +397,18 @@ public class Initializer { checks if the provided address is a valid sapling address */ public func isValidSaplingAddress(_ address: String) -> Bool { - rustBackend.isValidSaplingAddress(address, networkType: network.networkType) + DerivationTool(networkType: network.networkType).isValidSaplingAddress(address) } /** checks if the provided address is a transparent zAddress */ public func isValidTransparentAddress(_ address: String) -> Bool { - rustBackend.isValidTransparentAddress(address, networkType: network.networkType) + DerivationTool(networkType: network.networkType).isValidTransparentAddress(address) + } + + public func makeDerivationTool() -> DerivationTool { + return DerivationTool(networkType: network.networkType) } } diff --git a/Sources/ZcashLightClientKit/Model/WalletTypes.swift b/Sources/ZcashLightClientKit/Model/WalletTypes.swift index 71647549..06e49060 100644 --- a/Sources/ZcashLightClientKit/Model/WalletTypes.swift +++ b/Sources/ZcashLightClientKit/Model/WalletTypes.swift @@ -63,7 +63,7 @@ public struct UnifiedFullViewingKey: Equatable, StringEncoded, Undescribable { /// - Throws: `KeyEncodingError.invalidEncoding`when the provided encoding is /// found to be invalid public init(encoding: String, account: UInt32, network: NetworkType) throws { - guard DerivationTool.rustwelding.isValidUnifiedFullViewingKey(encoding, networkType: network) else { + guard DerivationTool(networkType: network).isValidUnifiedFullViewingKey(encoding) else { throw KeyEncodingError.invalidEncoding } @@ -85,7 +85,7 @@ public struct SaplingExtendedFullViewingKey: Equatable, StringEncoded, Undescrib /// - Throws: `KeyEncodingError.invalidEncoding`when the provided encoding is /// found to be invalid public init(encoding: String, network: NetworkType) throws { - guard DerivationTool.rustwelding.isValidSaplingExtendedFullViewingKey(encoding, networkType: network) else { + guard ZcashKeyDerivationBackend(networkType: network).isValidSaplingExtendedFullViewingKey(encoding) else { throw KeyEncodingError.invalidEncoding } self.encoding = encoding @@ -174,6 +174,8 @@ public struct SaplingAddress: Equatable, StringEncoded { } public struct UnifiedAddress: Equatable, StringEncoded { + let networkType: NetworkType + public enum Errors: Error { case couldNotExtractTypecodes } @@ -212,6 +214,7 @@ public struct UnifiedAddress: Equatable, StringEncoded { /// - Throws: `KeyEncodingError.invalidEncoding`when the provided encoding is /// found to be invalid public init(encoding: String, network: NetworkType) throws { + networkType = network guard DerivationTool(networkType: network).isValidUnifiedAddress(encoding) else { throw KeyEncodingError.invalidEncoding } @@ -224,7 +227,7 @@ public struct UnifiedAddress: Equatable, StringEncoded { /// couldn't be extracted public func availableReceiverTypecodes() throws -> [UnifiedAddress.ReceiverTypecodes] { do { - return try DerivationTool.receiverTypecodesFromUnifiedAddress(self) + return try DerivationTool(networkType: networkType).receiverTypecodesFromUnifiedAddress(self) } catch { throw Errors.couldNotExtractTypecodes } @@ -272,7 +275,7 @@ public enum Recipient: Equatable, StringEncoded { metadata.networkType) case .p2sh: return (.transparent(TransparentAddress(validatedEncoding: encoded)), metadata.networkType) case .sapling: return (.sapling(SaplingAddress(validatedEncoding: encoded)), metadata.networkType) - case .unified: return (.unified(UnifiedAddress(validatedEncoding: encoded)), metadata.networkType) + case .unified: return (.unified(UnifiedAddress(validatedEncoding: encoded, networkType: metadata.networkType)), metadata.networkType) } } } diff --git a/Sources/ZcashLightClientKit/Rust/ZcashKeyDerivationBackend.swift b/Sources/ZcashLightClientKit/Rust/ZcashKeyDerivationBackend.swift new file mode 100644 index 00000000..195102ad --- /dev/null +++ b/Sources/ZcashLightClientKit/Rust/ZcashKeyDerivationBackend.swift @@ -0,0 +1,225 @@ +// +// ZcashKeyDerivationBackend.swift +// +// +// Created by Francisco Gindre on 4/7/23. +// + +import Foundation +import libzcashlc + +struct ZcashKeyDerivationBackend: ZcashKeyDerivationBackendWelding { + let networkType: NetworkType + + init(networkType: NetworkType) { + self.networkType = networkType + } + + // MARK: Address metadata and validation + static func getAddressMetadata(_ address: String) -> AddressMetadata? { + var networkId: UInt32 = 0 + var addrId: UInt32 = 0 + guard zcashlc_get_address_metadata( + [CChar](address.utf8CString), + &networkId, + &addrId + ) else { + return nil + } + + guard + let network = NetworkType.forNetworkId(networkId), + let addrType = AddressType.forId(addrId) + else { + return nil + } + + return AddressMetadata(network: network, addrType: addrType) + } + + func receiverTypecodesOnUnifiedAddress(_ address: String) throws -> [UInt32] { + guard !address.containsCStringNullBytesBeforeStringEnding() else { + throw RustWeldingError.invalidInput(message: "`address` contains null bytes.") + } + + var len = UInt(0) + + guard let typecodesPointer = zcashlc_get_typecodes_for_unified_address_receivers( + [CChar](address.utf8CString), + &len + ), len > 0 + else { + throw RustWeldingError.malformedStringInput + } + + var typecodes: [UInt32] = [] + + for typecodeIndex in 0 ..< Int(len) { + let pointer = typecodesPointer.advanced(by: typecodeIndex) + + typecodes.append(pointer.pointee) + } + + defer { + zcashlc_free_typecodes(typecodesPointer, len) + } + + return typecodes + } + + func isValidSaplingAddress(_ address: String) -> Bool { + guard !address.containsCStringNullBytesBeforeStringEnding() else { + return false + } + + return zcashlc_is_valid_shielded_address([CChar](address.utf8CString), networkType.networkId) + } + + func isValidSaplingExtendedFullViewingKey(_ key: String) -> Bool { + guard !key.containsCStringNullBytesBeforeStringEnding() else { + return false + } + + return zcashlc_is_valid_viewing_key([CChar](key.utf8CString), networkType.networkId) + } + + func isValidSaplingExtendedSpendingKey(_ key: String) -> Bool { + guard !key.containsCStringNullBytesBeforeStringEnding() else { + return false + } + + return zcashlc_is_valid_sapling_extended_spending_key([CChar](key.utf8CString), networkType.networkId) + } + + func isValidTransparentAddress(_ address: String) -> Bool { + guard !address.containsCStringNullBytesBeforeStringEnding() else { + return false + } + + return zcashlc_is_valid_transparent_address([CChar](address.utf8CString), networkType.networkId) + } + + func isValidUnifiedAddress(_ address: String) -> Bool { + guard !address.containsCStringNullBytesBeforeStringEnding() else { + return false + } + + return zcashlc_is_valid_unified_address([CChar](address.utf8CString), networkType.networkId) + } + + func isValidUnifiedFullViewingKey(_ key: String) -> Bool { + guard !key.containsCStringNullBytesBeforeStringEnding() else { + return false + } + + return zcashlc_is_valid_unified_full_viewing_key([CChar](key.utf8CString), networkType.networkId) + } + + // MARK: Address Derivation + + func deriveUnifiedSpendingKey( + from seed: [UInt8], + accountIndex: Int32 + ) async throws -> UnifiedSpendingKey { + let binaryKeyPtr = seed.withUnsafeBufferPointer { seedBufferPtr in + return zcashlc_derive_spending_key( + seedBufferPtr.baseAddress, + UInt(seed.count), + accountIndex, + networkType.networkId + ) + } + + defer { zcashlc_free_binary_key(binaryKeyPtr) } + + guard let binaryKey = binaryKeyPtr?.pointee else { + throw lastError() ?? .genericError(message: "No error message available") + } + + return binaryKey.unsafeToUnifiedSpendingKey(network: networkType) + } + + func deriveUnifiedFullViewingKey(from spendingKey: UnifiedSpendingKey) async throws -> UnifiedFullViewingKey { + let extfvk = try spendingKey.bytes.withUnsafeBufferPointer { uskBufferPtr -> UnsafeMutablePointer in + guard let extfvk = zcashlc_spending_key_to_full_viewing_key( + uskBufferPtr.baseAddress, + UInt(spendingKey.bytes.count), + networkType.networkId + ) else { + throw lastError() ?? .genericError(message: "No error message available") + } + + return extfvk + } + + defer { zcashlc_string_free(extfvk) } + + guard let derived = String(validatingUTF8: extfvk) else { + throw RustWeldingError.unableToDeriveKeys + } + + return UnifiedFullViewingKey(validatedEncoding: derived, account: spendingKey.account) + } + + func getSaplingReceiver(for uAddr: UnifiedAddress) throws -> SaplingAddress { + guard let saplingCStr = zcashlc_get_sapling_receiver_for_unified_address( + [CChar](uAddr.encoding.utf8CString) + ) else { + throw KeyDerivationErrors.invalidUnifiedAddress + } + + defer { zcashlc_string_free(saplingCStr) } + + guard let saplingReceiverStr = String(validatingUTF8: saplingCStr) else { + throw KeyDerivationErrors.receiverNotFound + } + + return SaplingAddress(validatedEncoding: saplingReceiverStr) + } + + func getTransparentReceiver(for uAddr: UnifiedAddress) throws -> TransparentAddress { + guard let transparentCStr = zcashlc_get_transparent_receiver_for_unified_address( + [CChar](uAddr.encoding.utf8CString) + ) else { + throw KeyDerivationErrors.invalidUnifiedAddress + } + + defer { zcashlc_string_free(transparentCStr) } + + guard let transparentReceiverStr = String(validatingUTF8: transparentCStr) else { + throw KeyDerivationErrors.receiverNotFound + } + + return TransparentAddress(validatedEncoding: transparentReceiverStr) + } + + // MARK: Error Handling + + private func lastError() -> RustWeldingError? { + defer { zcashlc_clear_last_error() } + + guard let message = getLastError() else { + return nil + } + + if message.contains("couldn't load Sapling spend parameters") { + return RustWeldingError.saplingSpendParametersNotFound + } else if message.contains("is not empty") { + return RustWeldingError.dataDbNotEmpty + } + + return RustWeldingError.genericError(message: message) + } + + private func getLastError() -> String? { + let errorLen = zcashlc_last_error_length() + if errorLen > 0 { + let error = UnsafeMutablePointer.allocate(capacity: Int(errorLen)) + zcashlc_error_message_utf8(error, errorLen) + zcashlc_clear_last_error() + return String(validatingUTF8: error) + } else { + return nil + } + } +} diff --git a/Sources/ZcashLightClientKit/Rust/ZcashKeyDerivationBackendWelding.swift b/Sources/ZcashLightClientKit/Rust/ZcashKeyDerivationBackendWelding.swift new file mode 100644 index 00000000..b774cce9 --- /dev/null +++ b/Sources/ZcashLightClientKit/Rust/ZcashKeyDerivationBackendWelding.swift @@ -0,0 +1,82 @@ +// +// ZcashKeyDerivationBackendWelding.swift +// +// +// Created by Michal Fousek on 11.04.2023. +// + +import Foundation + +protocol ZcashKeyDerivationBackendWelding { + /// The network type this `ZcashKeyDerivationBackendWelding` implementation is for + var networkType: NetworkType { get } + + /// Returns the network and address type for the given Zcash address string, + /// if the string represents a valid Zcash address. + /// - Note: not `NetworkType` bound + static func getAddressMetadata(_ address: String) -> AddressMetadata? + + /// Obtains the available receiver typecodes for the given String encoded Unified Address + /// - Parameter address: public key represented as a String + /// - Returns the `[UInt32]` that compose the given UA + /// - Throws `RustWeldingError.invalidInput(message: String)` when the UA is either invalid or malformed + /// - Note: not `NetworkType` bound + func receiverTypecodesOnUnifiedAddress(_ address: String) throws -> [UInt32] + + /// Validates the if the given string is a valid Sapling Address + /// - Parameter address: UTF-8 encoded String to validate + /// - Returns: true when the address is valid. Returns false in any other case + /// - Throws: Error when the provided address belongs to another network + func isValidSaplingAddress(_ address: String) -> Bool + + /// Validates the if the given string is a valid Sapling Extended Full Viewing Key + /// - Parameter key: UTF-8 encoded String to validate + /// - Returns: `true` when the Sapling Extended Full Viewing Key is valid. `false` in any other case + /// - Throws: Error when there's another problem not related to validity of the string in question + func isValidSaplingExtendedFullViewingKey(_ key: String) -> Bool + + /// Validates the if the given string is a valid Sapling Extended Spending Key + /// - Returns: `true` when the Sapling Extended Spending Key is valid, false in any other case. + /// - Throws: Error when the key is semantically valid but it belongs to another network + /// - parameter key: String encoded Extended Spending Key + func isValidSaplingExtendedSpendingKey(_ key: String) -> Bool + + /// Validates the if the given string is a valid Transparent Address + /// - Parameter address: UTF-8 encoded String to validate + /// - Returns: true when the address is valid and transparent. false in any other case + func isValidTransparentAddress(_ address: String) -> Bool + + /// validates whether a string encoded address is a valid Unified Address. + /// - Parameter address: UTF-8 encoded String to validate + /// - Returns: true when the address is valid and transparent. false in any other case + func isValidUnifiedAddress(_ address: String) -> Bool + + /// verifies that the given string-encoded `UnifiedFullViewingKey` is valid. + /// - Parameter ufvk: UTF-8 encoded String to validate + /// - Returns: true when the encoded string is a valid UFVK. false in any other case + func isValidUnifiedFullViewingKey(_ ufvk: String) -> Bool + + /// Derives and returns a unified spending key from the given seed for the given account ID. + /// Returns the binary encoding of the spending key. The caller should manage the memory of (and store, if necessary) the returned spending key in a secure fashion. + /// - Parameter seed: a Byte Array with the seed + /// - Parameter accountIndex:account index that the key can spend from + func deriveUnifiedSpendingKey(from seed: [UInt8], accountIndex: Int32) async throws -> UnifiedSpendingKey + + /// Derives a `UnifiedFullViewingKey` from a `UnifiedSpendingKey` + /// - Parameter spendingKey: the `UnifiedSpendingKey` to derive from + /// - Throws: `RustWeldingError.unableToDeriveKeys` if the SDK couldn't derive the UFVK. + /// - Returns: the derived `UnifiedFullViewingKey` + func deriveUnifiedFullViewingKey(from spendingKey: UnifiedSpendingKey) async throws -> UnifiedFullViewingKey + + /// Returns the Sapling receiver within the given Unified Address, if any. + /// - Parameter uAddr: a `UnifiedAddress` + /// - Returns a `SaplingAddress` if any + /// - Throws `receiverNotFound` when the receiver is not found. `invalidUnifiedAddress` if the UA provided is not valid + func getSaplingReceiver(for uAddr: UnifiedAddress) throws -> SaplingAddress + + /// Returns the transparent receiver within the given Unified Address, if any. + /// - parameter uAddr: a `UnifiedAddress` + /// - Returns a `TransparentAddress` if any + /// - Throws `receiverNotFound` when the receiver is not found. `invalidUnifiedAddress` if the UA provided is not valid + func getTransparentReceiver(for uAddr: UnifiedAddress) throws -> TransparentAddress +} diff --git a/Sources/ZcashLightClientKit/Rust/ZcashRustBackend.swift b/Sources/ZcashLightClientKit/Rust/ZcashRustBackend.swift index bfefe45e..78046ccf 100644 --- a/Sources/ZcashLightClientKit/Rust/ZcashRustBackend.swift +++ b/Sources/ZcashLightClientKit/Rust/ZcashRustBackend.swift @@ -9,12 +9,37 @@ import Foundation import libzcashlc -class ZcashRustBackend: ZcashRustBackendWelding { - static let minimumConfirmations: UInt32 = 10 - static let useZIP317Fees = false - static func createAccount(dbData: URL, seed: [UInt8], networkType: NetworkType) async throws -> UnifiedSpendingKey { - let dbData = dbData.osStr() +actor ZcashRustBackend: ZcashRustBackendWelding { + let minimumConfirmations: UInt32 = 10 + let useZIP317Fees = false + let dbData: (String, UInt) + let fsBlockDbRoot: (String, UInt) + let spendParamsPath: (String, UInt) + let outputParamsPath: (String, UInt) + let keyDeriving: ZcashKeyDerivationBackendWelding + + nonisolated let networkType: NetworkType + + /// Creates instance of `ZcashRustBackend`. + /// - Parameters: + /// - dbData: `URL` pointing to file where data database will be. + /// - fsBlockDbRoot: `URL` pointing to the filesystem root directory where the fsBlock cache is. + /// this directory is expected to contain a `/blocks` sub-directory with the blocks stored in the convened filename + /// format `{height}-{hash}-block`. This directory has must be granted both write and read permissions. + /// - spendParamsPath: `URL` pointing to spend parameters file. + /// - outputParamsPath: `URL` pointing to output parameters file. + /// - networkType: Network type to use. + init(dbData: URL, fsBlockDbRoot: URL, spendParamsPath: URL, outputParamsPath: URL, networkType: NetworkType) { + self.dbData = dbData.osStr() + self.fsBlockDbRoot = fsBlockDbRoot.osPathStr() + self.spendParamsPath = spendParamsPath.osPathStr() + self.outputParamsPath = outputParamsPath.osPathStr() + self.networkType = networkType + self.keyDeriving = ZcashKeyDerivationBackend(networkType: networkType) + } + + func createAccount(seed: [UInt8]) async throws -> UnifiedSpendingKey { guard let ffiBinaryKeyPtr = zcashlc_create_account( dbData.0, dbData.1, @@ -30,20 +55,13 @@ class ZcashRustBackend: ZcashRustBackendWelding { return ffiBinaryKeyPtr.pointee.unsafeToUnifiedSpendingKey(network: networkType) } - // swiftlint:disable function_parameter_count - static func createToAddress( - dbData: URL, + func createToAddress( usk: UnifiedSpendingKey, to address: String, value: Int64, - memo: MemoBytes?, - spendParamsPath: String, - outputParamsPath: String, - networkType: NetworkType - ) async -> Int64 { - let dbData = dbData.osStr() - - return usk.bytes.withUnsafeBufferPointer { uskPtr in + memo: MemoBytes? + ) async throws -> Int64 { + let result = usk.bytes.withUnsafeBufferPointer { uskPtr in zcashlc_create_to_address( dbData.0, dbData.1, @@ -52,55 +70,39 @@ class ZcashRustBackend: ZcashRustBackendWelding { [CChar](address.utf8CString), value, memo?.bytes, - spendParamsPath, - UInt(spendParamsPath.lengthOfBytes(using: .utf8)), - outputParamsPath, - UInt(outputParamsPath.lengthOfBytes(using: .utf8)), + spendParamsPath.0, + spendParamsPath.1, + outputParamsPath.0, + outputParamsPath.1, networkType.networkId, minimumConfirmations, useZIP317Fees ) } + + guard result > 0 else { + throw lastError() ?? .genericError(message: "No error message available") + } + + return result } - static func decryptAndStoreTransaction(dbData: URL, txBytes: [UInt8], minedHeight: Int32, networkType: NetworkType) async -> Bool { - let dbData = dbData.osStr() - return zcashlc_decrypt_and_store_transaction( + func decryptAndStoreTransaction(txBytes: [UInt8], minedHeight: Int32) async throws { + let result = zcashlc_decrypt_and_store_transaction( dbData.0, dbData.1, txBytes, UInt(txBytes.count), UInt32(minedHeight), networkType.networkId - ) != 0 - } + ) - static func deriveUnifiedSpendingKey( - from seed: [UInt8], - accountIndex: Int32, - networkType: NetworkType - ) throws -> UnifiedSpendingKey { - let binaryKeyPtr = seed.withUnsafeBufferPointer { seedBufferPtr in - return zcashlc_derive_spending_key( - seedBufferPtr.baseAddress, - UInt(seed.count), - accountIndex, - networkType.networkId - ) - } - - defer { zcashlc_free_binary_key(binaryKeyPtr) } - - guard let binaryKey = binaryKeyPtr?.pointee else { + guard result != 0 else { throw lastError() ?? .genericError(message: "No error message available") } - - return binaryKey.unsafeToUnifiedSpendingKey(network: networkType) } - static func getBalance(dbData: URL, account: Int32, networkType: NetworkType) async throws -> Int64 { - let dbData = dbData.osStr() - + func getBalance(account: Int32) async throws -> Int64 { let balance = zcashlc_get_balance(dbData.0, dbData.1, account, networkType.networkId) guard balance >= 0 else { @@ -110,13 +112,7 @@ class ZcashRustBackend: ZcashRustBackendWelding { return balance } - static func getCurrentAddress( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> UnifiedAddress { - let dbData = dbData.osStr() - + func getCurrentAddress(account: Int32) async throws -> UnifiedAddress { guard let addressCStr = zcashlc_get_current_address( dbData.0, dbData.1, @@ -132,31 +128,25 @@ class ZcashRustBackend: ZcashRustBackendWelding { throw RustWeldingError.unableToDeriveKeys } - return UnifiedAddress(validatedEncoding: address) + return UnifiedAddress(validatedEncoding: address, networkType: networkType) } - static func getNearestRewindHeight( - dbData: URL, - height: Int32, - networkType: NetworkType - ) async -> Int32 { - let dbData = dbData.osStr() - - return zcashlc_get_nearest_rewind_height( + func getNearestRewindHeight(height: Int32) async throws -> Int32 { + let result = zcashlc_get_nearest_rewind_height( dbData.0, dbData.1, height, networkType.networkId ) + + guard result > 0 else { + throw lastError() ?? .genericError(message: "No error message available") + } + + return result } - static func getNextAvailableAddress( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> UnifiedAddress { - let dbData = dbData.osStr() - + func getNextAvailableAddress(account: Int32) async throws -> UnifiedAddress { guard let addressCStr = zcashlc_get_next_available_address( dbData.0, dbData.1, @@ -172,16 +162,10 @@ class ZcashRustBackend: ZcashRustBackendWelding { throw RustWeldingError.unableToDeriveKeys } - return UnifiedAddress(validatedEncoding: address) + return UnifiedAddress(validatedEncoding: address, networkType: networkType) } - static func getReceivedMemo( - dbData: URL, - idNote: Int64, - networkType: NetworkType - ) async -> Memo? { - let dbData = dbData.osStr() - + func getReceivedMemo(idNote: Int64) async -> Memo? { var contiguousMemoBytes = ContiguousArray(MemoBytes.empty().bytes) var success = false @@ -194,29 +178,7 @@ class ZcashRustBackend: ZcashRustBackendWelding { return (try? MemoBytes(contiguousBytes: contiguousMemoBytes)).flatMap { try? $0.intoMemo() } } - static func getSaplingReceiver(for uAddr: UnifiedAddress) throws -> SaplingAddress { - guard let saplingCStr = zcashlc_get_sapling_receiver_for_unified_address( - [CChar](uAddr.encoding.utf8CString) - ) else { - throw KeyDerivationErrors.invalidUnifiedAddress - } - - defer { zcashlc_string_free(saplingCStr) } - - guard let saplingReceiverStr = String(validatingUTF8: saplingCStr) else { - throw KeyDerivationErrors.receiverNotFound - } - - return SaplingAddress(validatedEncoding: saplingReceiverStr) - } - - static func getSentMemo( - dbData: URL, - idNote: Int64, - networkType: NetworkType - ) async -> Memo? { - let dbData = dbData.osStr() - + func getSentMemo(idNote: Int64) async -> Memo? { var contiguousMemoBytes = ContiguousArray(MemoBytes.empty().bytes) var success = false @@ -229,16 +191,11 @@ class ZcashRustBackend: ZcashRustBackendWelding { return (try? MemoBytes(contiguousBytes: contiguousMemoBytes)).flatMap { try? $0.intoMemo() } } - static func getTransparentBalance( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> Int64 { + func getTransparentBalance(account: Int32) async throws -> Int64 { guard account >= 0 else { throw RustWeldingError.invalidInput(message: "Account index must be non-negative") } - let dbData = dbData.osStr() let balance = zcashlc_get_total_transparent_balance_for_account( dbData.0, dbData.1, @@ -253,8 +210,7 @@ class ZcashRustBackend: ZcashRustBackendWelding { return balance } - static func getVerifiedBalance(dbData: URL, account: Int32, networkType: NetworkType) async throws -> Int64 { - let dbData = dbData.osStr() + func getVerifiedBalance(account: Int32) async throws -> Int64 { let balance = zcashlc_get_verified_balance( dbData.0, dbData.1, @@ -270,17 +226,11 @@ class ZcashRustBackend: ZcashRustBackendWelding { return balance } - static func getVerifiedTransparentBalance( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> Int64 { + func getVerifiedTransparentBalance(account: Int32) async throws -> Int64 { guard account >= 0 else { throw RustWeldingError.invalidInput(message: "`account` must be non-negative") } - let dbData = dbData.osStr() - let balance = zcashlc_get_verified_transparent_balance_for_account( dbData.0, dbData.1, @@ -299,8 +249,8 @@ class ZcashRustBackend: ZcashRustBackendWelding { return balance } - - static func lastError() -> RustWeldingError? { + + private nonisolated func lastError() -> RustWeldingError? { defer { zcashlc_clear_last_error() } guard let message = getLastError() else { @@ -316,43 +266,7 @@ class ZcashRustBackend: ZcashRustBackendWelding { return RustWeldingError.genericError(message: message) } - static func getAddressMetadata(_ address: String) -> AddressMetadata? { - var networkId: UInt32 = 0 - var addrId: UInt32 = 0 - guard zcashlc_get_address_metadata( - [CChar](address.utf8CString), - &networkId, - &addrId - ) else { - return nil - } - - guard let network = NetworkType.forNetworkId(networkId), - let addrType = AddressType.forId(addrId) - else { - return nil - } - - return AddressMetadata(network: network, addrType: addrType) - } - - static func getTransparentReceiver(for uAddr: UnifiedAddress) throws -> TransparentAddress { - guard let transparentCStr = zcashlc_get_transparent_receiver_for_unified_address( - [CChar](uAddr.encoding.utf8CString) - ) else { - throw KeyDerivationErrors.invalidUnifiedAddress - } - - defer { zcashlc_string_free(transparentCStr) } - - guard let transparentReceiverStr = String(validatingUTF8: transparentCStr) else { - throw KeyDerivationErrors.receiverNotFound - } - - return TransparentAddress(validatedEncoding: transparentReceiverStr) - } - - static func getLastError() -> String? { + private nonisolated func getLastError() -> String? { let errorLen = zcashlc_last_error_length() if errorLen > 0 { let error = UnsafeMutablePointer.allocate(capacity: Int(errorLen)) @@ -364,8 +278,7 @@ class ZcashRustBackend: ZcashRustBackendWelding { } } - static func initDataDb(dbData: URL, seed: [UInt8]?, networkType: NetworkType) async throws -> DbInitResult { - let dbData = dbData.osStr() + func initDataDb(seed: [UInt8]?) async throws -> DbInitResult { switch zcashlc_init_data_database(dbData.0, dbData.1, seed, UInt(seed?.count ?? 0), networkType.networkId) { case 0: // ok return DbInitResult.success @@ -375,74 +288,20 @@ class ZcashRustBackend: ZcashRustBackendWelding { throw throwDataDbError(lastError() ?? .genericError(message: "No error message found")) } } - - static func isValidSaplingAddress(_ address: String, networkType: NetworkType) -> Bool { - guard !address.containsCStringNullBytesBeforeStringEnding() else { - return false - } - - return zcashlc_is_valid_shielded_address([CChar](address.utf8CString), networkType.networkId) - } - - static func isValidTransparentAddress(_ address: String, networkType: NetworkType) -> Bool { - guard !address.containsCStringNullBytesBeforeStringEnding() else { - return false - } - return zcashlc_is_valid_transparent_address([CChar](address.utf8CString), networkType.networkId) - } - - static func isValidSaplingExtendedFullViewingKey(_ key: String, networkType: NetworkType) -> Bool { - guard !key.containsCStringNullBytesBeforeStringEnding() else { - return false - } - - return zcashlc_is_valid_viewing_key([CChar](key.utf8CString), networkType.networkId) - } - - static func isValidSaplingExtendedSpendingKey(_ key: String, networkType: NetworkType) -> Bool { - guard !key.containsCStringNullBytesBeforeStringEnding() else { - return false - } - - return zcashlc_is_valid_sapling_extended_spending_key([CChar](key.utf8CString), networkType.networkId) - } - - static func isValidUnifiedAddress(_ address: String, networkType: NetworkType) -> Bool { - guard !address.containsCStringNullBytesBeforeStringEnding() else { - return false - } - - return zcashlc_is_valid_unified_address([CChar](address.utf8CString), networkType.networkId) - } - - static func isValidUnifiedFullViewingKey(_ key: String, networkType: NetworkType) -> Bool { - guard !key.containsCStringNullBytesBeforeStringEnding() else { - return false - } - - return zcashlc_is_valid_unified_full_viewing_key([CChar](key.utf8CString), networkType.networkId) - } - - static func initAccountsTable( - dbData: URL, - ufvks: [UnifiedFullViewingKey], - networkType: NetworkType - ) async throws { - let dbData = dbData.osStr() - + func initAccountsTable(ufvks: [UnifiedFullViewingKey]) async throws { var ffiUfvks: [FFIEncodedKey] = [] for ufvk in ufvks { guard !ufvk.encoding.containsCStringNullBytesBeforeStringEnding() else { throw RustWeldingError.invalidInput(message: "`UFVK` contains null bytes.") } - - guard self.isValidUnifiedFullViewingKey(ufvk.encoding, networkType: networkType) else { + + guard self.keyDeriving.isValidUnifiedFullViewingKey(ufvk.encoding) else { throw RustWeldingError.invalidInput(message: "UFVK is invalid.") } let ufvkCStr = [CChar](String(ufvk.encoding).utf8CString) - + let ufvkPtr = UnsafeMutablePointer.allocate(capacity: ufvkCStr.count) ufvkPtr.initialize(from: ufvkCStr, count: ufvkCStr.count) @@ -476,19 +335,15 @@ class ZcashRustBackend: ZcashRustBackendWelding { } } - static func initBlockMetadataDb(fsBlockDbRoot: URL) async throws -> Bool { - let blockDb = fsBlockDbRoot.osPathStr() - - let result = zcashlc_init_block_metadata_db(blockDb.0, blockDb.1) + func initBlockMetadataDb() async throws { + let result = zcashlc_init_block_metadata_db(fsBlockDbRoot.0, fsBlockDbRoot.1) guard result else { throw lastError() ?? .genericError(message: "`initAccountsTable` failed with unknown error") } - - return result } - static func writeBlocksMetadata(fsBlockDbRoot: URL, blocks: [ZcashCompactBlock]) async throws -> Bool { + func writeBlocksMetadata(blocks: [ZcashCompactBlock]) async throws { var ffiBlockMetaVec: [FFIBlockMeta] = [] for block in blocks { @@ -507,7 +362,7 @@ class ZcashRustBackend: ZcashRustBackendWelding { hashPtr.deallocate() ffiBlockMetaVec.deallocateElements() } - return false + throw RustWeldingError.writeBlocksMetadataAllocationProblem } ffiBlockMetaVec.append( @@ -530,50 +385,35 @@ class ZcashRustBackend: ZcashRustBackendWelding { defer { ffiBlockMetaVec.deallocateElements() } - let result = try contiguousFFIBlocks.withContiguousMutableStorageIfAvailable { ptr in + try contiguousFFIBlocks.withContiguousMutableStorageIfAvailable { ptr in var meta = FFIBlocksMeta() meta.ptr = ptr.baseAddress meta.len = len fsBlocks.initialize(to: meta) - let fsDb = fsBlockDbRoot.osPathStr() - - let res = zcashlc_write_block_metadata(fsDb.0, fsDb.1, fsBlocks) + let res = zcashlc_write_block_metadata(fsBlockDbRoot.0, fsBlockDbRoot.1, fsBlocks) guard res else { throw lastError() ?? RustWeldingError.genericError(message: "failed to write block metadata") } - - return res } - - guard let value = result else { - return false - } - - return value } - // swiftlint:disable function_parameter_count - static func initBlocksTable( - dbData: URL, + func initBlocksTable( height: Int32, hash: String, time: UInt32, - saplingTree: String, - networkType: NetworkType + saplingTree: String ) async throws { - let dbData = dbData.osStr() - guard !hash.containsCStringNullBytesBeforeStringEnding() else { throw RustWeldingError.invalidInput(message: "`hash` contains null bytes.") } - + guard !saplingTree.containsCStringNullBytesBeforeStringEnding() else { throw RustWeldingError.invalidInput(message: "`saplingTree` contains null bytes.") } - + guard zcashlc_init_blocks_table( dbData.0, dbData.1, @@ -587,19 +427,11 @@ class ZcashRustBackend: ZcashRustBackendWelding { } } - static func latestCachedBlockHeight(fsBlockDbRoot: URL) async -> BlockHeight { - let fsBlockDb = fsBlockDbRoot.osPathStr() - - return BlockHeight(zcashlc_latest_cached_block_height(fsBlockDb.0, fsBlockDb.1)) + func latestCachedBlockHeight() async -> BlockHeight { + return BlockHeight(zcashlc_latest_cached_block_height(fsBlockDbRoot.0, fsBlockDbRoot.1)) } - static func listTransparentReceivers( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> [TransparentAddress] { - let dbData = dbData.osStr() - + func listTransparentReceivers(account: Int32) async throws -> [TransparentAddress] { guard let encodedKeysPtr = zcashlc_list_transparent_receivers( dbData.0, dbData.1, @@ -627,18 +459,14 @@ class ZcashRustBackend: ZcashRustBackendWelding { return addresses } - - static func putUnspentTransparentOutput( - dbData: URL, + + func putUnspentTransparentOutput( txid: [UInt8], index: Int, script: [UInt8], value: Int64, - height: BlockHeight, - networkType: NetworkType - ) async throws -> Bool { - let dbData = dbData.osStr() - + height: BlockHeight + ) async throws { guard zcashlc_put_utxo( dbData.0, dbData.1, @@ -653,48 +481,51 @@ class ZcashRustBackend: ZcashRustBackendWelding { ) else { throw lastError() ?? .genericError(message: "No error message available") } - - return true - } - - static func validateCombinedChain(fsBlockDbRoot: URL, dbData: URL, networkType: NetworkType, limit: UInt32 = 0) async -> Int32 { - let dbCache = fsBlockDbRoot.osPathStr() - let dbData = dbData.osStr() - return zcashlc_validate_combined_chain(dbCache.0, dbCache.1, dbData.0, dbData.1, networkType.networkId, limit) - } - - static func rewindToHeight(dbData: URL, height: Int32, networkType: NetworkType) -> Bool { - let dbData = dbData.osStr() - return zcashlc_rewind_to_height(dbData.0, dbData.1, height, networkType.networkId) } - static func rewindCacheToHeight( - fsBlockDbRoot: URL, - height: Int32 - ) -> Bool { - let fsBlockCache = fsBlockDbRoot.osPathStr() + func validateCombinedChain(limit: UInt32 = 0) async throws { + let result = zcashlc_validate_combined_chain(fsBlockDbRoot.0, fsBlockDbRoot.1, dbData.0, dbData.1, networkType.networkId, limit) - return zcashlc_rewind_fs_block_cache_to_height(fsBlockCache.0, fsBlockCache.1, height) + switch result { + case -1: + return + case 0: + throw RustWeldingError.chainValidationFailed(message: getLastError()) + default: + throw RustWeldingError.invalidChain(upperBound: result) + } } - static func scanBlocks(fsBlockDbRoot: URL, dbData: URL, limit: UInt32 = 0, networkType: NetworkType) async -> Bool { - let dbCache = fsBlockDbRoot.osPathStr() - let dbData = dbData.osStr() - return zcashlc_scan_blocks(dbCache.0, dbCache.1, dbData.0, dbData.1, limit, networkType.networkId) != 0 + func rewindToHeight(height: Int32) async throws { + let result = zcashlc_rewind_to_height(dbData.0, dbData.1, height, networkType.networkId) + + guard result else { + throw lastError() ?? .genericError(message: "No error message available") + } } - - static func shieldFunds( - dbData: URL, + + func rewindCacheToHeight(height: Int32) async throws { + let result = zcashlc_rewind_fs_block_cache_to_height(fsBlockDbRoot.0, fsBlockDbRoot.1, height) + + guard result else { + throw lastError() ?? .genericError(message: "No error message available") + } + } + + func scanBlocks(limit: UInt32 = 0) async throws { + let result = zcashlc_scan_blocks(fsBlockDbRoot.0, fsBlockDbRoot.1, dbData.0, dbData.1, limit, networkType.networkId) + + guard result != 0 else { + throw lastError() ?? .genericError(message: "No error message available") + } + } + + func shieldFunds( usk: UnifiedSpendingKey, memo: MemoBytes?, - shieldingThreshold: Zatoshi, - spendParamsPath: String, - outputParamsPath: String, - networkType: NetworkType - ) async -> Int64 { - let dbData = dbData.osStr() - - return usk.bytes.withUnsafeBufferPointer { uskBuffer in + shieldingThreshold: Zatoshi + ) async throws -> Int64 { + let result = usk.bytes.withUnsafeBufferPointer { uskBuffer in zcashlc_shield_funds( dbData.0, dbData.1, @@ -702,82 +533,36 @@ class ZcashRustBackend: ZcashRustBackendWelding { UInt(usk.bytes.count), memo?.bytes, UInt64(shieldingThreshold.amount), - spendParamsPath, - UInt(spendParamsPath.lengthOfBytes(using: .utf8)), - outputParamsPath, - UInt(outputParamsPath.lengthOfBytes(using: .utf8)), + spendParamsPath.0, + spendParamsPath.1, + outputParamsPath.0, + outputParamsPath.1, networkType.networkId, minimumConfirmations, useZIP317Fees ) } - } - - static func deriveUnifiedFullViewingKey(from spendingKey: UnifiedSpendingKey, networkType: NetworkType) throws -> UnifiedFullViewingKey { - let extfvk = try spendingKey.bytes.withUnsafeBufferPointer { uskBufferPtr -> UnsafeMutablePointer in - guard let extfvk = zcashlc_spending_key_to_full_viewing_key( - uskBufferPtr.baseAddress, - UInt(spendingKey.bytes.count), - networkType.networkId - ) else { - throw lastError() ?? .genericError(message: "No error message available") - } - return extfvk + guard result > 0 else { + throw lastError() ?? .genericError(message: "No error message available") } - defer { zcashlc_string_free(extfvk) } - - guard let derived = String(validatingUTF8: extfvk) else { - throw RustWeldingError.unableToDeriveKeys - } - - return UnifiedFullViewingKey(validatedEncoding: derived, account: spendingKey.account) + return result } - static func receiverTypecodesOnUnifiedAddress(_ address: String) throws -> [UInt32] { - guard !address.containsCStringNullBytesBeforeStringEnding() else { - throw RustWeldingError.invalidInput(message: "`address` contains null bytes.") - } - - var len = UInt(0) - - guard let typecodesPointer = zcashlc_get_typecodes_for_unified_address_receivers( - [CChar](address.utf8CString), - &len - ), len > 0 - else { - throw RustWeldingError.malformedStringInput - } - - var typecodes: [UInt32] = [] - - for typecodeIndex in 0 ..< Int(len) { - let pointer = typecodesPointer.advanced(by: typecodeIndex) - - typecodes.append(pointer.pointee) - } - - defer { - zcashlc_free_typecodes(typecodesPointer, len) - } - - return typecodes - } - - static func consensusBranchIdFor(height: Int32, networkType: NetworkType) throws -> Int32 { + nonisolated func consensusBranchIdFor(height: Int32) throws -> Int32 { let branchId = zcashlc_branch_id_for_height(height, networkType.networkId) - + guard branchId != -1 else { throw RustWeldingError.noConsensusBranchId(height: height) } - + return branchId } } private extension ZcashRustBackend { - static func throwDataDbError(_ error: RustWeldingError) -> Error { + func throwDataDbError(_ error: RustWeldingError) -> Error { if case RustWeldingError.genericError(let message) = error, message.contains("is not empty") { return RustWeldingError.dataDbNotEmpty } @@ -785,7 +570,7 @@ private extension ZcashRustBackend { return RustWeldingError.dataDbInitFailed(message: error.localizedDescription) } - static func throwBalanceError(account: Int32, _ error: RustWeldingError?, fallbackMessage: String) -> Error { + func throwBalanceError(account: Int32, _ error: RustWeldingError?, fallbackMessage: String) -> Error { guard let balanceError = error else { return RustWeldingError.genericError(message: fallbackMessage) } @@ -865,6 +650,15 @@ extension RustWeldingError: LocalizedError { return "`.unableToDeriveKeys` the requested keys could not be derived from the source provided" case let .getBalanceError(account, error): return "`.getBalanceError` could not retrieve balance from account: \(account), error:\(error)" + case let .invalidChain(upperBound: upperBound): + return "`.validateCombinedChain` failed to validate chain. Upper bound: \(upperBound)." + case let .chainValidationFailed(message): + return """ + `.validateCombinedChain` failed to validate chain because of error unrelated to chain validity. \ + Message: \(String(describing: message)) + """ + case .writeBlocksMetadataAllocationProblem: + return "`.writeBlocksMetadata` failed to allocate memory on Swift side necessary to write blocks metadata to db." } } } diff --git a/Sources/ZcashLightClientKit/Rust/ZcashRustBackendWelding.swift b/Sources/ZcashLightClientKit/Rust/ZcashRustBackendWelding.swift index e39da64a..fcb100ae 100644 --- a/Sources/ZcashLightClientKit/Rust/ZcashRustBackendWelding.swift +++ b/Sources/ZcashLightClientKit/Rust/ZcashRustBackendWelding.swift @@ -19,6 +19,13 @@ enum RustWeldingError: Error { case getBalanceError(Int, Error) case invalidInput(message: String) case invalidRewind(suggestedHeight: Int32) + /// Thrown when `upperBound` if the combined chain is invalid. `upperBound` is the height of the highest invalid block (on the assumption that + /// the highest block in the cache database is correct). + case invalidChain(upperBound: Int32) + /// Thrown if there was an error during validation unrelated to chain validity. + case chainValidationFailed(message: String?) + /// Thrown if there was problem with memory allocation on the Swift side while trying to write blocks metadata to DB. + case writeBlocksMetadataAllocationProblem } enum ZcashRustBackendWeldingConstants { @@ -31,6 +38,7 @@ public enum DbInitResult { case seedRequired } +// sourcery: mockActor protocol ZcashRustBackendWelding { /// Adds the next available account-level spend authority, given the current set of [ZIP 316] /// account identifiers known, to the wallet database. @@ -46,83 +54,34 @@ protocol ZcashRustBackendWelding { /// By convention, wallets should only allow a new account to be generated after funds /// have been received by the currently-available account (in order to enable /// automated account recovery). - /// - parameter dbData: location of the data db /// - parameter seed: byte array of the zip32 seed - /// - parameter networkType: network type of this key /// - Returns: The `UnifiedSpendingKey` structs for the number of accounts created - /// - static func createAccount( - dbData: URL, - seed: [UInt8], - networkType: NetworkType - ) async throws -> UnifiedSpendingKey + func createAccount(seed: [UInt8]) async throws -> UnifiedSpendingKey /// Creates a transaction to the given address from the given account - /// - Parameter dbData: URL for the Data DB /// - Parameter usk: `UnifiedSpendingKey` for the account that controls the funds to be spent. /// - Parameter to: recipient address /// - Parameter value: transaction amount in Zatoshi /// - Parameter memo: the `MemoBytes` for this transaction. pass `nil` when sending to transparent receivers - /// - Parameter spendParamsPath: path escaped String for the filesystem locations where the spend parameters are located - /// - Parameter outputParamsPath: path escaped String for the filesystem locations where the output parameters are located - /// - Parameter networkType: network type of this key - // swiftlint:disable:next function_parameter_count - static func createToAddress( - dbData: URL, + func createToAddress( usk: UnifiedSpendingKey, to address: String, value: Int64, - memo: MemoBytes?, - spendParamsPath: String, - outputParamsPath: String, - networkType: NetworkType - ) async -> Int64 // swiftlint:disable function_parameter_count - - /// Scans a transaction for any information that can be decrypted by the accounts in the - /// wallet, and saves it to the wallet. - /// - parameter dbData: location of the data db file - /// - parameter tx: the transaction to decrypt - /// - parameter minedHeight: height on which this transaction was mined. this is used to fetch the consensus branch ID. - /// - parameter networkType: network type of this key - /// returns false if fails to decrypt. - static func decryptAndStoreTransaction( - dbData: URL, - txBytes: [UInt8], - minedHeight: Int32, - networkType: NetworkType - ) async -> Bool - - /// Derives and returns a unified spending key from the given seed for the given account ID. - /// Returns the binary encoding of the spending key. The caller should manage the memory of (and store, if necessary) the returned spending key in a secure fashion. - /// - Parameter seed: a Byte Array with the seed - /// - Parameter accountIndex:account index that the key can spend from - /// - Parameter networkType: network type of this key - /// - Throws `.unableToDerive` when there's an error - static func deriveUnifiedSpendingKey( - from seed: [UInt8], - accountIndex: Int32, - networkType: NetworkType - ) throws -> UnifiedSpendingKey - - /// get the (unverified) balance from the given account - /// - parameter dbData: location of the data db - /// - parameter account: index of the given account - /// - parameter networkType: network type of this key - static func getBalance( - dbData: URL, - account: Int32, - networkType: NetworkType + memo: MemoBytes? ) async throws -> Int64 - /// Returns the most-recently-generated unified payment address for the specified account. - /// - parameter dbData: location of the data db + /// Scans a transaction for any information that can be decrypted by the accounts in the wallet, and saves it to the wallet. + /// - parameter tx: the transaction to decrypt + /// - parameter minedHeight: height on which this transaction was mined. this is used to fetch the consensus branch ID. + func decryptAndStoreTransaction(txBytes: [UInt8], minedHeight: Int32) async throws + + /// Get the (unverified) balance from the given account. /// - parameter account: index of the given account - /// - parameter networkType: network type of this key - static func getCurrentAddress( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> UnifiedAddress + func getBalance(account: Int32) async throws -> Int64 + + /// Returns the most-recently-generated unified payment address for the specified account. + /// - parameter account: index of the given account + func getCurrentAddress(account: Int32) async throws -> UnifiedAddress /// Wallets might need to be rewound because of a reorg, or by user request. /// There are times where the wallet could get out of sync for many reasons and @@ -131,195 +90,64 @@ protocol ZcashRustBackendWelding { /// of sapling witnesses older than 100 blocks. So in order to reconstruct the witness /// tree that allows to spend notes from the given wallet the rewind can't be more than /// 100 blocks or back to the oldest unspent note that this wallet contains. - /// - parameter dbData: location of the data db file /// - parameter height: height you would like to rewind to. - /// - parameter networkType: network type of this key] /// - Returns: the blockheight of the nearest rewind height. - /// - static func getNearestRewindHeight( - dbData: URL, - height: Int32, - networkType: NetworkType - ) async -> Int32 + func getNearestRewindHeight(height: Int32) async throws -> Int32 /// Returns a newly-generated unified payment address for the specified account, with the next available diversifier. - /// - parameter dbData: location of the data db /// - parameter account: index of the given account - /// - parameter networkType: network type of this key - static func getNextAvailableAddress( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> UnifiedAddress + func getNextAvailableAddress(account: Int32) async throws -> UnifiedAddress - /// get received memo from note - /// - parameter dbData: location of the data db file + /// Get received memo from note. /// - parameter idNote: note_id of note where the memo is located - /// - parameter networkType: network type of this key - static func getReceivedMemo( - dbData: URL, - idNote: Int64, - networkType: NetworkType - ) async -> Memo? + func getReceivedMemo(idNote: Int64) async -> Memo? - /// Returns the Sapling receiver within the given Unified Address, if any. - /// - Parameter uAddr: a `UnifiedAddress` - /// - Returns a `SaplingAddress` if any - /// - Throws `receiverNotFound` when the receiver is not found. `invalidUnifiedAddress` if the UA provided is not valid - static func getSaplingReceiver(for uAddr: UnifiedAddress) throws -> SaplingAddress - - /// get sent memo from note - /// - parameter dbData: location of the data db file + /// Get sent memo from note. /// - parameter idNote: note_id of note where the memo is located - /// - parameter networkType: network type of this key /// - Returns: a `Memo` if any - static func getSentMemo( - dbData: URL, - idNote: Int64, - networkType: NetworkType - ) async -> Memo? + func getSentMemo(idNote: Int64) async -> Memo? /// Get the verified cached transparent balance for the given address - /// - parameter dbData: location of the data db file /// - parameter account; the account index to query - /// - parameter networkType: network type of this key - static func getTransparentBalance( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> Int64 + func getTransparentBalance(account: Int32) async throws -> Int64 - /// Returns the transparent receiver within the given Unified Address, if any. - /// - parameter uAddr: a `UnifiedAddress` - /// - Returns a `TransparentAddress` if any - /// - Throws `receiverNotFound` when the receiver is not found. `invalidUnifiedAddress` if the UA provided is not valid - static func getTransparentReceiver(for uAddr: UnifiedAddress) throws -> TransparentAddress - - /// gets the latest error if available. Clear the existing error - /// - Returns a `RustWeldingError` if exists - static func lastError() -> RustWeldingError? - - /// gets the latest error message from librustzcash. Does not clear existing error - static func getLastError() -> String? - - /// initialize the accounts table from a set of unified full viewing keys - /// - Note: this function should only be used when restoring an existing seed phrase. - /// when creating a new wallet, use `createAccount()` instead - /// - Parameter dbData: location of the data db + /// Initialize the accounts table from a set of unified full viewing keys. + /// - Note: this function should only be used when restoring an existing seed phrase. when creating a new wallet, use `createAccount()` instead. /// - Parameter ufvks: an array of UnifiedFullViewingKeys - /// - Parameter networkType: network type of this key - static func initAccountsTable( - dbData: URL, - ufvks: [UnifiedFullViewingKey], - networkType: NetworkType - ) async throws + func initAccountsTable(ufvks: [UnifiedFullViewingKey]) async throws - /// initializes the data db. This will performs any migrations needed on the sqlite file + /// Initializes the data db. This will performs any migrations needed on the sqlite file /// provided. Some migrations might need that callers provide the seed bytes. - /// - Parameter dbData: location of the data db sql file /// - Parameter seed: ZIP-32 compliant seed bytes for this wallet - /// - Parameter networkType: network type of this key /// - Returns: `DbInitResult.success` if the dataDb was initialized successfully /// or `DbInitResult.seedRequired` if the operation requires the seed to be passed /// in order to be completed successfully. - static func initDataDb( - dbData: URL, - seed: [UInt8]?, - networkType: NetworkType - ) async throws -> DbInitResult + func initDataDb(seed: [UInt8]?) async throws -> DbInitResult - /// Returns the network and address type for the given Zcash address string, - /// if the string represents a valid Zcash address. - static func getAddressMetadata(_ address: String) -> AddressMetadata? - - /// Validates the if the given string is a valid Sapling Address - /// - Parameter address: UTF-8 encoded String to validate - /// - Parameter networkType: network type of this key - /// - Returns: true when the address is valid. Returns false in any other case - /// - Throws: Error when the provided address belongs to another network - static func isValidSaplingAddress(_ address: String, networkType: NetworkType) -> Bool - - /// Validates the if the given string is a valid Sapling Extended Full Viewing Key - /// - Parameter key: UTF-8 encoded String to validate - /// - Parameter networkType: network type of this key - /// - Returns: `true` when the Sapling Extended Full Viewing Key is valid. `false` in any other case - /// - Throws: Error when there's another problem not related to validity of the string in question - static func isValidSaplingExtendedFullViewingKey(_ key: String, networkType: NetworkType) -> Bool - - /// Validates the if the given string is a valid Sapling Extended Spending Key - /// - Returns: `true` when the Sapling Extended Spending Key is valid, false in any other case. - /// - Throws: Error when the key is semantically valid but it belongs to another network - /// - parameter key: String encoded Extended Spending Key - /// - parameter networkType: `NetworkType` signaling testnet or mainnet - static func isValidSaplingExtendedSpendingKey(_ key: String, networkType: NetworkType) -> Bool - - /// Validates the if the given string is a valid Transparent Address - /// - Parameter address: UTF-8 encoded String to validate - /// - Parameter networkType: network type of this key - /// - Returns: true when the address is valid and transparent. false in any other case - /// - Throws: Error when the provided address belongs to another network - static func isValidTransparentAddress(_ address: String, networkType: NetworkType) -> Bool - - /// validates whether a string encoded address is a valid Unified Address. - /// - Parameter address: UTF-8 encoded String to validate - /// - Parameter networkType: network type of this key - /// - Returns: true when the address is valid and transparent. false in any other case - /// - Throws: Error when the provided address belongs to another network - static func isValidUnifiedAddress(_ address: String, networkType: NetworkType) -> Bool - - /// verifies that the given string-encoded `UnifiedFullViewingKey` is valid. - /// - Parameter ufvk: UTF-8 encoded String to validate - /// - Parameter networkType: network type of this key - /// - Returns: true when the encoded string is a valid UFVK. false in any other case - /// - Throws: Error when there's another problem not related to validity of the string in question - static func isValidUnifiedFullViewingKey(_ ufvk: String, networkType: NetworkType) -> Bool - - /// initialize the blocks table from a given checkpoint (heigh, hash, time, saplingTree and networkType) - /// - parameter dbData: location of the data db + /// Initialize the blocks table from a given checkpoint (heigh, hash, time, saplingTree and networkType). /// - parameter height: represents the block height of the given checkpoint /// - parameter hash: hash of the merkle tree /// - parameter time: in milliseconds from reference /// - parameter saplingTree: hash of the sapling tree - /// - parameter networkType: `NetworkType` signaling testnet or mainnet - static func initBlocksTable( - dbData: URL, + func initBlocksTable( height: Int32, hash: String, time: UInt32, - saplingTree: String, - networkType: NetworkType - ) async throws // swiftlint:disable function_parameter_count + saplingTree: String + ) async throws /// Returns a list of the transparent receivers for the diversified unified addresses that have /// been allocated for the provided account. - /// - parameter dbData: location of the data db /// - parameter account: index of the given account - /// - parameter networkType: the network type - static func listTransparentReceivers( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> [TransparentAddress] + func listTransparentReceivers(account: Int32) async throws -> [TransparentAddress] - /// get the verified balance from the given account - /// - parameter dbData: location of the data db + /// Get the verified balance from the given account /// - parameter account: index of the given account - /// - parameter networkType: the network type - static func getVerifiedBalance( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> Int64 + func getVerifiedBalance(account: Int32) async throws -> Int64 /// Get the verified cached transparent balance for the given account - /// - parameter dbData: location of the data db /// - parameter account: account index to query the balance for. - /// - parameter networkType: the network type - static func getVerifiedTransparentBalance( - dbData: URL, - account: Int32, - networkType: NetworkType - ) async throws -> Int64 + func getVerifiedTransparentBalance(account: Int32) async throws -> Int64 /// Checks that the scanned blocks in the data database, when combined with the recent /// `CompactBlock`s in the cache database, form a valid chain. @@ -333,39 +161,22 @@ protocol ZcashRustBackendWelding { /// - parameter dbData: location of the data db file /// - parameter networkType: the network type /// - parameter limit: a limit to validate a fixed number of blocks instead of the whole cache. - /// - Returns: - /// - `-1` if the combined chain is valid. - /// - `upper_bound` if the combined chain is invalid. - /// - `upper_bound` is the height of the highest invalid block (on the assumption that the highest block in the cache database is correct). - /// - `0` if there was an error during validation unrelated to chain validity. + /// - Throws: + /// - `RustWeldingError.chainValidationFailed` if there was an error during validation unrelated to chain validity. + /// - `RustWeldingError.invalidChain(upperBound)` if the combined chain is invalid. `upperBound` is the height of the highest invalid block + /// (on the assumption that the highest block in the cache database is correct). + /// /// - Important: This function does not mutate either of the databases. - static func validateCombinedChain( - fsBlockDbRoot: URL, - dbData: URL, - networkType: NetworkType, - limit: UInt32 - ) async -> Int32 + func validateCombinedChain(limit: UInt32) async throws /// Resets the state of the database to only contain block and transaction information up to the given height. clears up all derived data as well - /// - parameter dbData: `URL` pointing to the filesystem root directory where the fsBlock cache is. - /// this directory is expected to contain a `/blocks` sub-directory with the blocks stored in the convened filename - /// format `{height}-{hash}-block`. This directory has must be granted both write and read permissions. /// - parameter height: height to rewind to. - static func rewindToHeight( - dbData: URL, - height: Int32, - networkType: NetworkType - ) async -> Bool + func rewindToHeight(height: Int32) async throws /// Resets the state of the FsBlock database to only contain block and transaction information up to the given height. /// - Note: this does not delete the files. Only rolls back the database. - /// - parameter fsBlockDbRoot: location of the data db file - /// - parameter height: height to rewind to. DON'T PASS ARBITRARY HEIGHT. Use getNearestRewindHeight when unsure - /// - parameter networkType: the network type - static func rewindCacheToHeight( - fsBlockDbRoot: URL, - height: Int32 - ) async -> Bool + /// - parameter height: height to rewind to. DON'T PASS ARBITRARY HEIGHT. Use `getNearestRewindHeight` when unsure + func rewindCacheToHeight(height: Int32) async throws /// Scans new blocks added to the cache for any transactions received by the tracked /// accounts. @@ -379,101 +190,48 @@ protocol ZcashRustBackendWelding { /// Scanned blocks are required to be height-sequential. If a block is missing from the /// cache, an error will be signalled. /// - /// - parameter fsBlockDbRoot: `URL` pointing to the filesystem root directory where the fsBlock cache is. - /// this directory is expected to contain a `/blocks` sub-directory with the blocks stored in the convened filename - /// format `{height}-{hash}-block`. This directory has must be granted both write and read permissions. - /// - parameter dbData: location of the data db sqlite file /// - parameter limit: scan up to limit blocks. pass 0 to set no limit. - /// - parameter networkType: the network type - /// returns false if fails to scan. - static func scanBlocks( - fsBlockDbRoot: URL, - dbData: URL, - limit: UInt32, - networkType: NetworkType - ) async -> Bool + func scanBlocks(limit: UInt32) async throws /// Upserts a UTXO into the data db database - /// - parameter dbData: location of the data db file /// - parameter txid: the txid bytes for the UTXO /// - parameter index: the index of the UTXO /// - parameter script: the script of the UTXO /// - parameter value: the value of the UTXO /// - parameter height: the mined height for the UTXO - /// - parameter networkType: the network type - /// - Returns: true if the operation succeded or false otherwise - static func putUnspentTransparentOutput( - dbData: URL, + func putUnspentTransparentOutput( txid: [UInt8], index: Int, script: [UInt8], value: Int64, - height: BlockHeight, - networkType: NetworkType - ) async throws -> Bool + height: BlockHeight + ) async throws /// Creates a transaction to shield all found UTXOs in data db for the account the provided `UnifiedSpendingKey` has spend authority for. - /// - Parameter dbData: URL for the Data DB /// - Parameter usk: `UnifiedSpendingKey` that spend transparent funds and where the funds will be shielded to. /// - Parameter memo: the `Memo` for this transaction - /// - Parameter spendParamsPath: path escaped String for the filesystem locations where the spend parameters are located - /// - Parameter outputParamsPath: path escaped String for the filesystem locations where the output parameters are located - /// - Parameter networkType: the network type - static func shieldFunds( - dbData: URL, + func shieldFunds( usk: UnifiedSpendingKey, memo: MemoBytes?, - shieldingThreshold: Zatoshi, - spendParamsPath: String, - outputParamsPath: String, - networkType: NetworkType - ) async -> Int64 - - /// Obtains the available receiver typecodes for the given String encoded Unified Address - /// - Parameter address: public key represented as a String - /// - Returns the `[UInt32]` that compose the given UA - /// - Throws `RustWeldingError.invalidInput(message: String)` when the UA is either invalid or malformed - static func receiverTypecodesOnUnifiedAddress(_ address: String) throws -> [UInt32] + shieldingThreshold: Zatoshi + ) async throws -> Int64 /// Gets the consensus branch id for the given height /// - Parameter height: the height you what to know the branch id for - /// - Parameter networkType: the network type - static func consensusBranchIdFor( - height: Int32, - networkType: NetworkType - ) throws -> Int32 + func consensusBranchIdFor(height: Int32) throws -> Int32 - /// Derives a `UnifiedFullViewingKey` from a `UnifiedSpendingKey` - /// - Parameter spendingKey: the `UnifiedSpendingKey` to derive from - /// - Parameter networkType: the network type - /// - Throws: `RustWeldingError.unableToDeriveKeys` if the SDK couldn't derive the UFVK. - /// - Returns: the derived `UnifiedFullViewingKey` - static func deriveUnifiedFullViewingKey( - from spendingKey: UnifiedSpendingKey, - networkType: NetworkType - ) throws -> UnifiedFullViewingKey - - /// initializes Filesystem based block cache - /// - Parameter fsBlockDbRoot: `URL` pointing to the filesystem root directory where the fsBlock cache is. - /// this directory is expected to contain a `/blocks` sub-directory with the blocks stored in the convened filename - /// format `{height}-{hash}-block`. This directory has must be granted both write and read permissions. - /// - returns `true` when successful, `false` when fails but no throwing information was found + /// Initializes Filesystem based block cache /// - throws `RustWeldingError` when fails to initialize - static func initBlockMetadataDb(fsBlockDbRoot: URL) async throws -> Bool + func initBlockMetadataDb() async throws /// Write compact block metadata to a database known to the Rust layer - /// - Parameter fsBlockDbRoot: `URL` pointing to the filesystem root directory where the fsBlock cache is. - /// this directory is expected to contain a `/blocks` sub-directory with the blocks stored in the convened filename - /// format `{height}-{hash}-block`. This directory has must be granted both write and read permissions. - /// - Parameter blocks: The `ZcashCompactBlock`s that are going to be marked as stored by the - /// metadata Db. - /// - Returns `true` if the operation was successful, `false` otherwise. - static func writeBlocksMetadata(fsBlockDbRoot: URL, blocks: [ZcashCompactBlock]) async throws -> Bool + /// - Parameter blocks: The `ZcashCompactBlock`s that are going to be marked as stored by the metadata Db. + func writeBlocksMetadata(blocks: [ZcashCompactBlock]) async throws /// Gets the latest block height stored in the filesystem based cache. /// - Parameter fsBlockDbRoot: `URL` pointing to the filesystem root directory where the fsBlock cache is. /// this directory is expected to contain a `/blocks` sub-directory with the blocks stored in the convened filename /// format `{height}-{hash}-block`. This directory has must be granted both write and read permissions. /// - Returns `BlockHeight` of the latest cached block or `.empty` if no blocks are stored. - static func latestCachedBlockHeight(fsBlockDbRoot: URL) async -> BlockHeight + func latestCachedBlockHeight() async -> BlockHeight } diff --git a/Sources/ZcashLightClientKit/Synchronizer.swift b/Sources/ZcashLightClientKit/Synchronizer.swift index fed08df2..bbbcc5b0 100644 --- a/Sources/ZcashLightClientKit/Synchronizer.swift +++ b/Sources/ZcashLightClientKit/Synchronizer.swift @@ -230,24 +230,34 @@ public protocol Synchronizer: AnyObject { func paginatedTransactions(of kind: TransactionKind) -> PaginatedTransactionRepository /// Get all memos for `transaction`. + /// + // sourcery: mockedName="getMemosForClearedTransaction" func getMemos(for transaction: ZcashTransaction.Overview) async throws -> [Memo] /// Get all memos for `receivedTransaction`. + /// + // sourcery: mockedName="getMemosForReceivedTransaction" func getMemos(for receivedTransaction: ZcashTransaction.Received) async throws -> [Memo] /// Get all memos for `sentTransaction`. + /// + // sourcery: mockedName="getMemosForSentTransaction" func getMemos(for sentTransaction: ZcashTransaction.Sent) async throws -> [Memo] /// Attempt to get recipients from a Transaction Overview. /// - parameter transaction: A transaction overview /// - returns the recipients or an empty array if no recipients are found on this transaction because it's not an outgoing /// transaction + /// + // sourcery: mockedName="getRecipientsForClearedTransaction" func getRecipients(for transaction: ZcashTransaction.Overview) async -> [TransactionRecipient] /// Get the recipients for the given a sent transaction /// - parameter transaction: A transaction overview /// - returns the recipients or an empty array if no recipients are found on this transaction because it's not an outgoing /// transaction + /// + // sourcery: mockedName="getRecipientsForSentTransaction" func getRecipients(for transaction: ZcashTransaction.Sent) async -> [TransactionRecipient] /// Returns a list of confirmed transactions that preceed the given transaction with a limit count. diff --git a/Sources/ZcashLightClientKit/Synchronizer/ClosureSDKSynchronizer.swift b/Sources/ZcashLightClientKit/Synchronizer/ClosureSDKSynchronizer.swift index 915c9719..d5fff219 100644 --- a/Sources/ZcashLightClientKit/Synchronizer/ClosureSDKSynchronizer.swift +++ b/Sources/ZcashLightClientKit/Synchronizer/ClosureSDKSynchronizer.swift @@ -37,37 +37,37 @@ extension ClosureSDKSynchronizer: ClosureSynchronizer { walletBirthday: BlockHeight, completion: @escaping (Result) -> Void ) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { return try await self.synchronizer.prepare(with: seed, viewingKeys: viewingKeys, walletBirthday: walletBirthday) } } public func start(retry: Bool, completion: @escaping (Error?) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.start(retry: retry) } } public func stop(completion: @escaping () -> Void) { - executeAction(completion) { + AsyncToClosureGateway.executeAction(completion) { await self.synchronizer.stop() } } public func getSaplingAddress(accountIndex: Int, completion: @escaping (Result) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.getSaplingAddress(accountIndex: accountIndex) } } public func getUnifiedAddress(accountIndex: Int, completion: @escaping (Result) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.getUnifiedAddress(accountIndex: accountIndex) } } public func getTransparentAddress(accountIndex: Int, completion: @escaping (Result) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.getTransparentAddress(accountIndex: accountIndex) } } @@ -79,7 +79,7 @@ extension ClosureSDKSynchronizer: ClosureSynchronizer { memo: Memo?, completion: @escaping (Result) -> Void ) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.sendToAddress(spendingKey: spendingKey, zatoshi: zatoshi, toAddress: toAddress, memo: memo) } } @@ -90,37 +90,37 @@ extension ClosureSDKSynchronizer: ClosureSynchronizer { shieldingThreshold: Zatoshi, completion: @escaping (Result) -> Void ) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.shieldFunds(spendingKey: spendingKey, memo: memo, shieldingThreshold: shieldingThreshold) } } public func cancelSpend(transaction: PendingTransactionEntity, completion: @escaping (Bool) -> Void) { - executeAction(completion) { + AsyncToClosureGateway.executeAction(completion) { await self.synchronizer.cancelSpend(transaction: transaction) } } public func pendingTransactions(completion: @escaping ([PendingTransactionEntity]) -> Void) { - executeAction(completion) { + AsyncToClosureGateway.executeAction(completion) { await self.synchronizer.pendingTransactions } } public func clearedTransactions(completion: @escaping ([ZcashTransaction.Overview]) -> Void) { - executeAction(completion) { + AsyncToClosureGateway.executeAction(completion) { await self.synchronizer.clearedTransactions } } public func sentTranscations(completion: @escaping ([ZcashTransaction.Sent]) -> Void) { - executeAction(completion) { + AsyncToClosureGateway.executeAction(completion) { await self.synchronizer.sentTransactions } } public func receivedTransactions(completion: @escaping ([ZcashTransaction.Received]) -> Void) { - executeAction(completion) { + AsyncToClosureGateway.executeAction(completion) { await self.synchronizer.receivedTransactions } } @@ -128,31 +128,31 @@ extension ClosureSDKSynchronizer: ClosureSynchronizer { public func paginatedTransactions(of kind: TransactionKind) -> PaginatedTransactionRepository { synchronizer.paginatedTransactions(of: kind) } public func getMemos(for transaction: ZcashTransaction.Overview, completion: @escaping (Result<[Memo], Error>) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.getMemos(for: transaction) } } public func getMemos(for receivedTransaction: ZcashTransaction.Received, completion: @escaping (Result<[Memo], Error>) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.getMemos(for: receivedTransaction) } } public func getMemos(for sentTransaction: ZcashTransaction.Sent, completion: @escaping (Result<[Memo], Error>) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.getMemos(for: sentTransaction) } } public func getRecipients(for transaction: ZcashTransaction.Overview, completion: @escaping ([TransactionRecipient]) -> Void) { - executeAction(completion) { + AsyncToClosureGateway.executeAction(completion) { await self.synchronizer.getRecipients(for: transaction) } } public func getRecipients(for transaction: ZcashTransaction.Sent, completion: @escaping ([TransactionRecipient]) -> Void) { - executeAction(completion) { + AsyncToClosureGateway.executeAction(completion) { await self.synchronizer.getRecipients(for: transaction) } } @@ -162,37 +162,37 @@ extension ClosureSDKSynchronizer: ClosureSynchronizer { limit: Int, completion: @escaping (Result<[ZcashTransaction.Overview], Error>) -> Void ) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.allConfirmedTransactions(from: transaction, limit: limit) } } public func latestHeight(completion: @escaping (Result) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.latestHeight() } } public func refreshUTXOs(address: TransparentAddress, from height: BlockHeight, completion: @escaping (Result) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.refreshUTXOs(address: address, from: height) } } public func getTransparentBalance(accountIndex: Int, completion: @escaping (Result) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.getTransparentBalance(accountIndex: accountIndex) } } public func getShieldedBalance(accountIndex: Int, completion: @escaping (Result) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.getShieldedBalance(accountIndex: accountIndex) } } public func getShieldedVerifiedBalance(accountIndex: Int, completion: @escaping (Result) -> Void) { - executeThrowingAction(completion) { + AsyncToClosureGateway.executeThrowingAction(completion) { try await self.synchronizer.getShieldedVerifiedBalance(accountIndex: accountIndex) } } @@ -204,41 +204,3 @@ extension ClosureSDKSynchronizer: ClosureSynchronizer { public func rewind(_ policy: RewindPolicy) -> CompletablePublisher { synchronizer.rewind(policy) } public func wipe() -> CompletablePublisher { synchronizer.wipe() } } - -extension ClosureSDKSynchronizer { - private func executeAction(_ completion: @escaping () -> Void, action: @escaping () async -> Void) { - Task { - await action() - completion() - } - } - - private func executeAction(_ completion: @escaping (R) -> Void, action: @escaping () async -> R) { - Task { - let result = await action() - completion(result) - } - } - - private func executeThrowingAction(_ completion: @escaping (Error?) -> Void, action: @escaping () async throws -> Void) { - Task { - do { - try await action() - completion(nil) - } catch { - completion(error) - } - } - } - - private func executeThrowingAction(_ completion: @escaping (Result) -> Void, action: @escaping () async throws -> R) { - Task { - do { - let result = try await action() - completion(.success(result)) - } catch { - completion(.failure(error)) - } - } - } -} diff --git a/Sources/ZcashLightClientKit/Synchronizer/CombineSDKSynchronizer.swift b/Sources/ZcashLightClientKit/Synchronizer/CombineSDKSynchronizer.swift index f725050e..ea531dce 100644 --- a/Sources/ZcashLightClientKit/Synchronizer/CombineSDKSynchronizer.swift +++ b/Sources/ZcashLightClientKit/Synchronizer/CombineSDKSynchronizer.swift @@ -36,37 +36,37 @@ extension CombineSDKSynchronizer: CombineSynchronizer { viewingKeys: [UnifiedFullViewingKey], walletBirthday: BlockHeight ) -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { return try await self.synchronizer.prepare(with: seed, viewingKeys: viewingKeys, walletBirthday: walletBirthday) } } public func start(retry: Bool) -> CompletablePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.start(retry: retry) } } public func stop() -> CompletablePublisher { - return executeAction() { + AsyncToCombineGateway.executeAction() { await self.synchronizer.stop() } } public func getSaplingAddress(accountIndex: Int) -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.getSaplingAddress(accountIndex: accountIndex) } } public func getUnifiedAddress(accountIndex: Int) -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.getUnifiedAddress(accountIndex: accountIndex) } } public func getTransparentAddress(accountIndex: Int) -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.getTransparentAddress(accountIndex: accountIndex) } } @@ -77,7 +77,7 @@ extension CombineSDKSynchronizer: CombineSynchronizer { toAddress: Recipient, memo: Memo? ) -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.sendToAddress(spendingKey: spendingKey, zatoshi: zatoshi, toAddress: toAddress, memo: memo) } } @@ -87,37 +87,37 @@ extension CombineSDKSynchronizer: CombineSynchronizer { memo: Memo, shieldingThreshold: Zatoshi ) -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.shieldFunds(spendingKey: spendingKey, memo: memo, shieldingThreshold: shieldingThreshold) } } public func cancelSpend(transaction: PendingTransactionEntity) -> SinglePublisher { - executeAction() { + AsyncToCombineGateway.executeAction() { await self.synchronizer.cancelSpend(transaction: transaction) } } public var pendingTransactions: SinglePublisher<[PendingTransactionEntity], Never> { - executeAction() { + AsyncToCombineGateway.executeAction() { await self.synchronizer.pendingTransactions } } public var clearedTransactions: SinglePublisher<[ZcashTransaction.Overview], Never> { - executeAction() { + AsyncToCombineGateway.executeAction() { await self.synchronizer.clearedTransactions } } public var sentTransactions: SinglePublisher<[ZcashTransaction.Sent], Never> { - executeAction() { + AsyncToCombineGateway.executeAction() { await self.synchronizer.sentTransactions } } public var receivedTransactions: SinglePublisher<[ZcashTransaction.Received], Never> { - executeAction() { + AsyncToCombineGateway.executeAction() { await self.synchronizer.receivedTransactions } } @@ -125,67 +125,67 @@ extension CombineSDKSynchronizer: CombineSynchronizer { public func paginatedTransactions(of kind: TransactionKind) -> PaginatedTransactionRepository { synchronizer.paginatedTransactions(of: kind) } public func getMemos(for transaction: ZcashTransaction.Overview) -> SinglePublisher<[Memo], Error> { - executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.getMemos(for: transaction) } } public func getMemos(for receivedTransaction: ZcashTransaction.Received) -> SinglePublisher<[Memo], Error> { - executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.getMemos(for: receivedTransaction) } } public func getMemos(for sentTransaction: ZcashTransaction.Sent) -> SinglePublisher<[Memo], Error> { - executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.getMemos(for: sentTransaction) } } public func getRecipients(for transaction: ZcashTransaction.Overview) -> SinglePublisher<[TransactionRecipient], Never> { - executeAction() { + AsyncToCombineGateway.executeAction() { await self.synchronizer.getRecipients(for: transaction) } } public func getRecipients(for transaction: ZcashTransaction.Sent) -> SinglePublisher<[TransactionRecipient], Never> { - executeAction() { + AsyncToCombineGateway.executeAction() { await self.synchronizer.getRecipients(for: transaction) } } public func allConfirmedTransactions(from transaction: ZcashTransaction.Overview, limit: Int) -> SinglePublisher<[ZcashTransaction.Overview], Error> { - executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.allConfirmedTransactions(from: transaction, limit: limit) } } public func latestHeight() -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.latestHeight() } } public func refreshUTXOs(address: TransparentAddress, from height: BlockHeight) -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.refreshUTXOs(address: address, from: height) } } public func getTransparentBalance(accountIndex: Int) -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await self.synchronizer.getTransparentBalance(accountIndex: accountIndex) } } public func getShieldedBalance(accountIndex: Int = 0) -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await synchronizer.getShieldedBalance(accountIndex: accountIndex) } } public func getShieldedVerifiedBalance(accountIndex: Int = 0) -> SinglePublisher { - return executeThrowingAction() { + AsyncToCombineGateway.executeThrowingAction() { try await synchronizer.getShieldedVerifiedBalance(accountIndex: accountIndex) } } @@ -193,53 +193,3 @@ extension CombineSDKSynchronizer: CombineSynchronizer { public func rewind(_ policy: RewindPolicy) -> CompletablePublisher { synchronizer.rewind(policy) } public func wipe() -> CompletablePublisher { synchronizer.wipe() } } - -extension CombineSDKSynchronizer { - private func executeAction(action: @escaping () async -> Void) -> CompletablePublisher { - let subject = PassthroughSubject() - Task { - await action() - subject.send(completion: .finished) - } - return subject.eraseToAnyPublisher() - } - - private func executeAction(action: @escaping () async -> R) -> SinglePublisher { - let subject = PassthroughSubject() - Task { - let result = await action() - subject.send(result) - subject.send(completion: .finished) - } - return subject.eraseToAnyPublisher() - } - - private func executeThrowingAction(action: @escaping () async throws -> Void) -> CompletablePublisher { - let subject = PassthroughSubject() - Task { - do { - try await action() - subject.send(completion: .finished) - } catch { - subject.send(completion: .failure(error)) - } - } - - return subject.eraseToAnyPublisher() - } - - private func executeThrowingAction(action: @escaping () async throws -> R) -> SinglePublisher { - let subject = PassthroughSubject() - Task { - do { - let result = try await action() - subject.send(result) - subject.send(completion: .finished) - } catch { - subject.send(completion: .failure(error)) - } - } - - return subject.eraseToAnyPublisher() - } -} diff --git a/Sources/ZcashLightClientKit/Synchronizer/SDKSynchronizer.swift b/Sources/ZcashLightClientKit/Synchronizer/SDKSynchronizer.swift index a721f7b7..6803a693 100644 --- a/Sources/ZcashLightClientKit/Synchronizer/SDKSynchronizer.swift +++ b/Sources/ZcashLightClientKit/Synchronizer/SDKSynchronizer.swift @@ -458,21 +458,13 @@ public class SDKSynchronizer: Synchronizer { } public func getShieldedBalance(accountIndex: Int = 0) async throws -> Zatoshi { - let balance = try await initializer.rustBackend.getBalance( - dbData: initializer.dataDbURL, - account: Int32(accountIndex), - networkType: network.networkType - ) + let balance = try await initializer.rustBackend.getBalance(account: Int32(accountIndex)) return Zatoshi(balance) } public func getShieldedVerifiedBalance(accountIndex: Int = 0) async throws -> Zatoshi { - let balance = try await initializer.rustBackend.getVerifiedBalance( - dbData: initializer.dataDbURL, - account: Int32(accountIndex), - networkType: network.networkType - ) + let balance = try await initializer.rustBackend.getVerifiedBalance(account: Int32(accountIndex)) return Zatoshi(balance) } @@ -617,7 +609,6 @@ public class SDKSynchronizer: Synchronizer { newState = await snapshotState(status: newStatus) } else { - newState = await SynchronizerState( syncSessionID: syncSession.value, shieldedBalance: latestState.shieldedBalance, diff --git a/Sources/ZcashLightClientKit/Tool/DerivationTool.swift b/Sources/ZcashLightClientKit/Tool/DerivationTool.swift index 6bfa8b7b..5e4d501d 100644 --- a/Sources/ZcashLightClientKit/Tool/DerivationTool.swift +++ b/Sources/ZcashLightClientKit/Tool/DerivationTool.swift @@ -5,17 +5,14 @@ // Created by Francisco Gindre on 10/8/20. // +import Combine import Foundation public protocol KeyValidation { func isValidUnifiedFullViewingKey(_ ufvk: String) -> Bool - func isValidTransparentAddress(_ tAddress: String) -> Bool - func isValidSaplingAddress(_ zAddress: String) -> Bool - func isValidSaplingExtendedSpendingKey(_ extsk: String) -> Bool - func isValidUnifiedAddress(_ unifiedAddress: String) -> Bool } @@ -25,26 +22,39 @@ public protocol KeyDeriving { /// - Parameter accountNumber: `Int` with the account number /// - Throws `.unableToDerive` if there's a problem deriving this key /// - Returns a `UnifiedSpendingKey` - func deriveUnifiedSpendingKey(seed: [UInt8], accountIndex: Int) throws -> UnifiedSpendingKey + func deriveUnifiedSpendingKey(seed: [UInt8], accountIndex: Int) async throws -> UnifiedSpendingKey + func deriveUnifiedSpendingKey(seed: [UInt8], accountIndex: Int, completion: @escaping (Result) -> Void) + func deriveUnifiedSpendingKey(seed: [UInt8], accountIndex: Int) -> SinglePublisher + + /// Given a spending key, return the associated viewing key. + /// - Parameter spendingKey: the `UnifiedSpendingKey` from which to derive the `UnifiedFullViewingKey` from. + /// - Returns: the viewing key that corresponds to the spending key. + func deriveUnifiedFullViewingKey(from spendingKey: UnifiedSpendingKey) async throws -> UnifiedFullViewingKey + func deriveUnifiedFullViewingKey(from spendingKey: UnifiedSpendingKey, completion: @escaping (Result) -> Void) + func deriveUnifiedFullViewingKey(from spendingKey: UnifiedSpendingKey) -> SinglePublisher /// Extracts the `SaplingAddress` from the given `UnifiedAddress` /// - Parameter address: the `UnifiedAddress` /// - Throws `KeyDerivationErrors.receiverNotFound` if the receiver is not present - static func saplingReceiver(from unifiedAddress: UnifiedAddress) throws -> SaplingAddress + func saplingReceiver(from unifiedAddress: UnifiedAddress) throws -> SaplingAddress /// Extracts the `TransparentAddress` from the given `UnifiedAddress` /// - Parameter address: the `UnifiedAddress` /// - Throws `KeyDerivationErrors.receiverNotFound` if the receiver is not present - static func transparentReceiver(from unifiedAddress: UnifiedAddress) throws -> TransparentAddress + func transparentReceiver(from unifiedAddress: UnifiedAddress) throws -> TransparentAddress /// Extracts the `UnifiedAddress.ReceiverTypecodes` from the given `UnifiedAddress` /// - Parameter address: the `UnifiedAddress` /// - Throws - static func receiverTypecodesFromUnifiedAddress(_ address: UnifiedAddress) throws -> [UnifiedAddress.ReceiverTypecodes] + func receiverTypecodesFromUnifiedAddress(_ address: UnifiedAddress) throws -> [UnifiedAddress.ReceiverTypecodes] + + static func getAddressMetadata(_ addr: String) -> AddressMetadata? } public enum KeyDerivationErrors: Error { case derivationError(underlyingError: Error) + // When something happens that is not related to derivation itself happens. For example if self is nil in closure. + case genericOtherError case unableToDerive case invalidInput case invalidUnifiedAddress @@ -52,31 +62,43 @@ public enum KeyDerivationErrors: Error { } public class DerivationTool: KeyDeriving { - static var rustwelding: ZcashRustBackendWelding.Type = ZcashRustBackend.self + let backend: ZcashKeyDerivationBackendWelding - var networkType: NetworkType - - public init(networkType: NetworkType) { - self.networkType = networkType + init(networkType: NetworkType) { + self.backend = ZcashKeyDerivationBackend(networkType: networkType) } - public static func saplingReceiver(from unifiedAddress: UnifiedAddress) throws -> SaplingAddress { - try rustwelding.getSaplingReceiver(for: unifiedAddress) + public func saplingReceiver(from unifiedAddress: UnifiedAddress) throws -> SaplingAddress { + try backend.getSaplingReceiver(for: unifiedAddress) } - public static func transparentReceiver(from unifiedAddress: UnifiedAddress) throws -> TransparentAddress { - try rustwelding.getTransparentReceiver(for: unifiedAddress) + public func transparentReceiver(from unifiedAddress: UnifiedAddress) throws -> TransparentAddress { + try backend.getTransparentReceiver(for: unifiedAddress) } public static func getAddressMetadata(_ addr: String) -> AddressMetadata? { - rustwelding.getAddressMetadata(addr) + ZcashKeyDerivationBackend.getAddressMetadata(addr) } /// Given a spending key, return the associated viewing key. /// - Parameter spendingKey: the `UnifiedSpendingKey` from which to derive the `UnifiedFullViewingKey` from. /// - Returns: the viewing key that corresponds to the spending key. - public func deriveUnifiedFullViewingKey(from spendingKey: UnifiedSpendingKey) throws -> UnifiedFullViewingKey { - try DerivationTool.rustwelding.deriveUnifiedFullViewingKey(from: spendingKey, networkType: self.networkType) + public func deriveUnifiedFullViewingKey(from spendingKey: UnifiedSpendingKey) async throws -> UnifiedFullViewingKey { + try await backend.deriveUnifiedFullViewingKey(from: spendingKey) + } + + public func deriveUnifiedFullViewingKey(from spendingKey: UnifiedSpendingKey, completion: @escaping (Result) -> Void) { + AsyncToClosureGateway.executeThrowingAction(completion) { [weak self] in + guard let self else { throw KeyDerivationErrors.genericOtherError } + return try await self.deriveUnifiedFullViewingKey(from: spendingKey) + } + } + + public func deriveUnifiedFullViewingKey(from spendingKey: UnifiedSpendingKey) -> SinglePublisher { + AsyncToCombineGateway.executeThrowingAction() { [weak self] in + guard let self else { throw KeyDerivationErrors.genericOtherError } + return try await self.deriveUnifiedFullViewingKey(from: spendingKey) + } } /// Given a seed and a number of accounts, return the associated spending keys. @@ -84,24 +106,34 @@ public class DerivationTool: KeyDeriving { /// - Parameter numberOfAccounts: the number of accounts to use. Multiple accounts are not fully /// supported so the default value of 1 is recommended. /// - Returns: the spending keys that correspond to the seed, formatted as Strings. - public func deriveUnifiedSpendingKey(seed: [UInt8], accountIndex: Int) throws -> UnifiedSpendingKey { + public func deriveUnifiedSpendingKey(seed: [UInt8], accountIndex: Int) async throws -> UnifiedSpendingKey { guard accountIndex >= 0, let accountIndex = Int32(exactly: accountIndex) else { throw KeyDerivationErrors.invalidInput } do { - return try DerivationTool.rustwelding.deriveUnifiedSpendingKey( - from: seed, - accountIndex: accountIndex, - networkType: self.networkType - ) + return try await backend.deriveUnifiedSpendingKey(from: seed, accountIndex: accountIndex) } catch { throw KeyDerivationErrors.unableToDerive } } - public static func receiverTypecodesFromUnifiedAddress(_ address: UnifiedAddress) throws -> [UnifiedAddress.ReceiverTypecodes] { + public func deriveUnifiedSpendingKey(seed: [UInt8], accountIndex: Int, completion: @escaping (Result) -> Void) { + AsyncToClosureGateway.executeThrowingAction(completion) { [weak self] in + guard let self else { throw KeyDerivationErrors.genericOtherError } + return try await self.deriveUnifiedSpendingKey(seed: seed, accountIndex: accountIndex) + } + } + + public func deriveUnifiedSpendingKey(seed: [UInt8], accountIndex: Int) -> SinglePublisher { + AsyncToCombineGateway.executeThrowingAction() { [weak self] in + guard let self else { throw KeyDerivationErrors.genericOtherError } + return try await self.deriveUnifiedSpendingKey(seed: seed, accountIndex: accountIndex) + } + } + + public func receiverTypecodesFromUnifiedAddress(_ address: UnifiedAddress) throws -> [UnifiedAddress.ReceiverTypecodes] { do { - return try DerivationTool.rustwelding.receiverTypecodesOnUnifiedAddress(address.stringEncoded) + return try backend.receiverTypecodesOnUnifiedAddress(address.stringEncoded) .map({ UnifiedAddress.ReceiverTypecodes(typecode: $0) }) } catch { throw KeyDerivationErrors.invalidUnifiedAddress @@ -121,23 +153,23 @@ public struct AddressMetadata { extension DerivationTool: KeyValidation { public func isValidUnifiedFullViewingKey(_ ufvk: String) -> Bool { - DerivationTool.rustwelding.isValidUnifiedFullViewingKey(ufvk, networkType: networkType) + backend.isValidUnifiedFullViewingKey(ufvk) } public func isValidUnifiedAddress(_ unifiedAddress: String) -> Bool { - DerivationTool.rustwelding.isValidUnifiedAddress(unifiedAddress, networkType: networkType) + backend.isValidUnifiedAddress(unifiedAddress) } public func isValidTransparentAddress(_ tAddress: String) -> Bool { - DerivationTool.rustwelding.isValidTransparentAddress(tAddress, networkType: networkType) + backend.isValidTransparentAddress(tAddress) } public func isValidSaplingAddress(_ zAddress: String) -> Bool { - DerivationTool.rustwelding.isValidSaplingAddress(zAddress, networkType: networkType) + backend.isValidSaplingAddress(zAddress) } public func isValidSaplingExtendedSpendingKey(_ extsk: String) -> Bool { - DerivationTool.rustwelding.isValidSaplingExtendedSpendingKey(extsk, networkType: networkType) + backend.isValidSaplingExtendedSpendingKey(extsk) } } @@ -166,8 +198,9 @@ extension UnifiedAddress { /// already validated by another function. only for internal use. Unless you are /// constructing an address from a primitive function of the FFI, you probably /// shouldn't be using this.. - init(validatedEncoding: String) { + init(validatedEncoding: String, networkType: NetworkType) { self.encoding = validatedEncoding + self.networkType = networkType } } @@ -206,22 +239,18 @@ public extension UnifiedSpendingKey { func map(_ transform: (UnifiedSpendingKey) throws -> T) rethrows -> T { try transform(self) } - - func deriveFullViewingKey() throws -> UnifiedFullViewingKey { - try DerivationTool(networkType: self.network).deriveUnifiedFullViewingKey(from: self) - } } public extension UnifiedAddress { /// Extracts the sapling receiver from this UA if available /// - Returns: an `Optional` func saplingReceiver() throws -> SaplingAddress { - try DerivationTool.saplingReceiver(from: self) + try DerivationTool(networkType: networkType).saplingReceiver(from: self) } /// Extracts the transparent receiver from this UA if available /// - Returns: an `Optional` func transparentReceiver() throws -> TransparentAddress { - try DerivationTool.transparentReceiver(from: self) + try DerivationTool(networkType: networkType).transparentReceiver(from: self) } } diff --git a/Sources/ZcashLightClientKit/Transaction/WalletTransactionEncoder.swift b/Sources/ZcashLightClientKit/Transaction/WalletTransactionEncoder.swift index d1ed988d..5da5d365 100644 --- a/Sources/ZcashLightClientKit/Transaction/WalletTransactionEncoder.swift +++ b/Sources/ZcashLightClientKit/Transaction/WalletTransactionEncoder.swift @@ -8,7 +8,7 @@ import Foundation class WalletTransactionEncoder: TransactionEncoder { - var rustBackend: ZcashRustBackendWelding.Type + var rustBackend: ZcashRustBackendWelding var repository: TransactionRepository let logger: Logger @@ -19,7 +19,7 @@ class WalletTransactionEncoder: TransactionEncoder { private var networkType: NetworkType init( - rust: ZcashRustBackendWelding.Type, + rustBackend: ZcashRustBackendWelding, dataDb: URL, fsBlockDbRoot: URL, repository: TransactionRepository, @@ -28,7 +28,7 @@ class WalletTransactionEncoder: TransactionEncoder { networkType: NetworkType, logger: Logger ) { - self.rustBackend = rust + self.rustBackend = rustBackend self.dataDbURL = dataDb self.fsBlockDbRoot = fsBlockDbRoot self.repository = repository @@ -40,7 +40,7 @@ class WalletTransactionEncoder: TransactionEncoder { convenience init(initializer: Initializer) { self.init( - rust: initializer.rustBackend, + rustBackend: initializer.rustBackend, dataDb: initializer.dataDbURL, fsBlockDbRoot: initializer.fsBlockDbRoot, repository: initializer.transactionRepository, @@ -84,22 +84,14 @@ class WalletTransactionEncoder: TransactionEncoder { guard ensureParams(spend: self.spendParamsURL, output: self.outputParamsURL) else { throw TransactionEncoderError.missingParams } - - let txId = await rustBackend.createToAddress( - dbData: self.dataDbURL, + + let txId = try await rustBackend.createToAddress( usk: spendingKey, to: address, value: zatoshi.amount, - memo: memoBytes, - spendParamsPath: self.spendParamsURL.path, - outputParamsPath: self.outputParamsURL.path, - networkType: networkType + memo: memoBytes ) - - guard txId > 0 else { - throw rustBackend.lastError() ?? RustWeldingError.genericError(message: "create spend failed") - } - + return Int(txId) } @@ -134,20 +126,12 @@ class WalletTransactionEncoder: TransactionEncoder { throw TransactionEncoderError.missingParams } - let txId = await rustBackend.shieldFunds( - dbData: self.dataDbURL, + let txId = try await rustBackend.shieldFunds( usk: spendingKey, memo: memo, - shieldingThreshold: shieldingThreshold, - spendParamsPath: self.spendParamsURL.path, - outputParamsPath: self.outputParamsURL.path, - networkType: networkType + shieldingThreshold: shieldingThreshold ) - - guard txId > 0 else { - throw rustBackend.lastError() ?? RustWeldingError.genericError(message: "create spend failed") - } - + return Int(txId) } diff --git a/Sources/ZcashLightClientKit/Utils/AsyncToClosureGateway.swift b/Sources/ZcashLightClientKit/Utils/AsyncToClosureGateway.swift new file mode 100644 index 00000000..78819e0f --- /dev/null +++ b/Sources/ZcashLightClientKit/Utils/AsyncToClosureGateway.swift @@ -0,0 +1,46 @@ +// +// AsyncToClosureGateway.swift +// +// +// Created by Michal Fousek on 03.04.2023. +// + +import Foundation + +enum AsyncToClosureGateway { + static func executeAction(_ completion: @escaping () -> Void, action: @escaping () async -> Void) { + Task { + await action() + completion() + } + } + + static func executeAction(_ completion: @escaping (R) -> Void, action: @escaping () async -> R) { + Task { + let result = await action() + completion(result) + } + } + + static func executeThrowingAction(_ completion: @escaping (Error?) -> Void, action: @escaping () async throws -> Void) { + Task { + do { + try await action() + completion(nil) + } catch { + completion(error) + } + } + } + + static func executeThrowingAction(_ completion: @escaping (Result) -> Void, action: @escaping () async throws -> R) { + Task { + do { + let result = try await action() + completion(.success(result)) + } catch { + completion(.failure(error)) + } + } + } +} diff --git a/Sources/ZcashLightClientKit/Utils/AsyncToCombineGateway.swift b/Sources/ZcashLightClientKit/Utils/AsyncToCombineGateway.swift new file mode 100644 index 00000000..33895634 --- /dev/null +++ b/Sources/ZcashLightClientKit/Utils/AsyncToCombineGateway.swift @@ -0,0 +1,59 @@ +// +// AsyncToCombineGateway.swift +// +// +// Created by Michal Fousek on 03.04.2023. +// + +import Combine +import Foundation + +enum AsyncToCombineGateway { + static func executeAction(action: @escaping () async -> Void) -> CompletablePublisher { + let subject = PassthroughSubject() + Task { + await action() + subject.send(completion: .finished) + } + return subject.eraseToAnyPublisher() + } + + static func executeAction(action: @escaping () async -> R) -> SinglePublisher { + let subject = PassthroughSubject() + Task { + let result = await action() + subject.send(result) + subject.send(completion: .finished) + } + return subject.eraseToAnyPublisher() + } + + static func executeThrowingAction(action: @escaping () async throws -> Void) -> CompletablePublisher { + let subject = PassthroughSubject() + Task { + do { + try await action() + subject.send(completion: .finished) + } catch { + subject.send(completion: .failure(error)) + } + } + + return subject.eraseToAnyPublisher() + } + + static func executeThrowingAction(action: @escaping () async throws -> R) -> SinglePublisher { + let subject = PassthroughSubject() + Task { + do { + let result = try await action() + subject.send(result) + subject.send(completion: .finished) + } catch { + subject.send(completion: .failure(error)) + } + } + + return subject.eraseToAnyPublisher() + } +} diff --git a/Sources/ZcashLightClientKit/Utils/SpecificCombineTypes.swift b/Sources/ZcashLightClientKit/Utils/SpecificCombineTypes.swift new file mode 100644 index 00000000..9f4a6cf8 --- /dev/null +++ b/Sources/ZcashLightClientKit/Utils/SpecificCombineTypes.swift @@ -0,0 +1,16 @@ +// +// SpecificCombineTypes.swift +// +// +// Created by Michal Fousek on 03.04.2023. +// + +import Combine +import Foundation + +/* These aliases are here to just make the API easier to read. */ + +// Publisher which emitts completed or error. No value is emitted. +public typealias CompletablePublisher = AnyPublisher +// Publisher that either emits one value and then finishes or it emits error. +public typealias SinglePublisher = AnyPublisher diff --git a/Sources/ZcashLightClientKit/Utils/SyncSessionIDGenerator.swift b/Sources/ZcashLightClientKit/Utils/SyncSessionIDGenerator.swift index 99a85127..16fde23a 100644 --- a/Sources/ZcashLightClientKit/Utils/SyncSessionIDGenerator.swift +++ b/Sources/ZcashLightClientKit/Utils/SyncSessionIDGenerator.swift @@ -7,7 +7,6 @@ import Foundation - protocol SyncSessionIDGenerator { func nextID() -> UUID } diff --git a/Tests/DarksideTests/BlockDownloaderTests.swift b/Tests/DarksideTests/BlockDownloaderTests.swift index 5157ac87..bf720fb7 100644 --- a/Tests/DarksideTests/BlockDownloaderTests.swift +++ b/Tests/DarksideTests/BlockDownloaderTests.swift @@ -13,10 +13,6 @@ import XCTest class BlockDownloaderTests: XCTestCase { let branchID = "2bb40e60" let chainName = "main" - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) let testFileManager = FileManager() var darksideWalletService: DarksideWalletService! @@ -24,16 +20,25 @@ class BlockDownloaderTests: XCTestCase { var service: LightWalletService! var storage: CompactBlockRepository! var network = DarksideWalletDNetwork() + var rustBackend: ZcashRustBackendWelding! + var testTempDirectory: URL! override func setUp() async throws { try await super.setUp() + testTempDirectory = Environment.uniqueTestTempDirectory + service = LightWalletServiceFactory(endpoint: LightWalletEndpointBuilder.default).make() + rustBackend = ZcashRustBackend.makeForTests( + fsBlockDbRoot: testTempDirectory, + networkType: network.networkType + ) + storage = FSCompactBlockRepository( fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: ZcashRustBackend.self, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -58,6 +63,8 @@ class BlockDownloaderTests: XCTestCase { service = nil storage = nil downloader = nil + rustBackend = nil + testTempDirectory = nil } func testSmallDownload() async { diff --git a/Tests/DarksideTests/RewindRescanTests.swift b/Tests/DarksideTests/RewindRescanTests.swift index 59d05b85..b3920655 100644 --- a/Tests/DarksideTests/RewindRescanTests.swift +++ b/Tests/DarksideTests/RewindRescanTests.swift @@ -173,11 +173,13 @@ class RewindRescanTests: XCTestCase { // rewind to birthday let targetHeight: BlockHeight = newChaintTip - 8000 - let rewindHeight = await ZcashRustBackend.getNearestRewindHeight( - dbData: coordinator.databases.dataDB, - height: Int32(targetHeight), - networkType: network.networkType - ) + + do { + _ = try await coordinator.synchronizer.initializer.rustBackend.getNearestRewindHeight(height: Int32(targetHeight)) + } catch { + XCTFail("get nearest height failed error: \(error)") + return + } let rewindExpectation = XCTestExpectation(description: "RewindExpectation") @@ -202,11 +204,6 @@ class RewindRescanTests: XCTestCase { wait(for: [rewindExpectation], timeout: 2) - guard rewindHeight > 0 else { - XCTFail("get nearest height failed error: \(ZcashRustBackend.getLastError() ?? "null")") - return - } - // check that the balance is cleared var expectedVerifiedBalance = try await coordinator.synchronizer.getShieldedVerifiedBalance() XCTAssertEqual(initialVerifiedBalance, expectedVerifiedBalance) diff --git a/Tests/DarksideTests/SynchronizerDarksideTests.swift b/Tests/DarksideTests/SynchronizerDarksideTests.swift index 253963f7..8751fe75 100644 --- a/Tests/DarksideTests/SynchronizerDarksideTests.swift +++ b/Tests/DarksideTests/SynchronizerDarksideTests.swift @@ -27,6 +27,7 @@ class SynchronizerDarksideTests: XCTestCase { var foundTransactions: [ZcashTransaction.Overview] = [] var cancellables: [AnyCancellable] = [] var idGenerator: MockSyncSessionIDGenerator! + override func setUp() async throws { try await super.setUp() idGenerator = MockSyncSessionIDGenerator(ids: [.deadbeef]) @@ -78,7 +79,6 @@ class SynchronizerDarksideTests: XCTestCase { } func testFoundManyTransactions() async throws { - self.idGenerator.ids = [.deadbeef, .beefbeef, .beefdead] coordinator.synchronizer.eventStream .map { event in @@ -143,7 +143,7 @@ class SynchronizerDarksideTests: XCTestCase { XCTAssertEqual(self.foundTransactions.count, 2) } - func testLastStates() async throws { + func sdfstestLastStates() async throws { self.idGenerator.ids = [.deadbeef] var cancellables: [AnyCancellable] = [] @@ -474,7 +474,6 @@ class SynchronizerDarksideTests: XCTestCase { ] XCTAssertEqual(states, secondBatchOfExpectedStates) - } func testSyncAfterWipeWorks() async throws { diff --git a/Tests/DarksideTests/TransactionEnhancementTests.swift b/Tests/DarksideTests/TransactionEnhancementTests.swift index e9b463ef..cdc99c4b 100644 --- a/Tests/DarksideTests/TransactionEnhancementTests.swift +++ b/Tests/DarksideTests/TransactionEnhancementTests.swift @@ -19,16 +19,14 @@ class TransactionEnhancementTests: XCTestCase { let network = DarksideWalletDNetwork() let branchID = "2bb40e60" let chainName = "main" - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) + var testTempDirectory: URL! let testFileManager = FileManager() var initializer: Initializer! var processorConfig: CompactBlockProcessor.Configuration! var processor: CompactBlockProcessor! + var rustBackend: ZcashRustBackendWelding! var darksideWalletService: DarksideWalletService! var downloader: BlockDownloaderServiceImpl! var syncStartedExpect: XCTestExpectation! @@ -42,8 +40,8 @@ class TransactionEnhancementTests: XCTestCase { override func setUp() async throws { try await super.setUp() - - try self.testFileManager.createDirectory(at: self.testTempDirectory, withIntermediateDirectories: false) + testTempDirectory = Environment.uniqueTestTempDirectory + try self.testFileManager.createDirectory(at: testTempDirectory, withIntermediateDirectories: false) await InternalSyncProgress( alias: .default, @@ -63,7 +61,6 @@ class TransactionEnhancementTests: XCTestCase { waitExpectation = XCTestExpectation(description: "\(self.description) waitExpectation") - let rustBackend = ZcashRustBackend.self let birthday = Checkpoint.birthday(with: walletBirthday, network: network) let pathProvider = DefaultResourceProvider(network: network) @@ -78,27 +75,25 @@ class TransactionEnhancementTests: XCTestCase { network: network ) + rustBackend = ZcashRustBackend.makeForTests( + dbData: processorConfig.dataDb, + fsBlockDbRoot: testTempDirectory, + networkType: network.networkType + ) + try? FileManager.default.removeItem(at: processorConfig.fsBlockCacheRoot) try? FileManager.default.removeItem(at: processorConfig.dataDb) - let dbInit = try await rustBackend.initDataDb(dbData: processorConfig.dataDb, seed: nil, networkType: network.networkType) - - let ufvks = [ - try DerivationTool(networkType: network.networkType) - .deriveUnifiedSpendingKey(seed: Environment.seedBytes, accountIndex: 0) - .map { - try DerivationTool(networkType: network.networkType) - .deriveUnifiedFullViewingKey(from: $0) - } - ] + let dbInit = try await rustBackend.initDataDb(seed: nil) + + let derivationTool = DerivationTool(networkType: network.networkType) + let spendingKey = try await derivationTool.deriveUnifiedSpendingKey(seed: Environment.seedBytes, accountIndex: 0) + let viewingKey = try await derivationTool.deriveUnifiedFullViewingKey(from: spendingKey) + do { - try await rustBackend.initAccountsTable( - dbData: processorConfig.dataDb, - ufvks: ufvks, - networkType: network.networkType - ) + try await rustBackend.initAccountsTable(ufvks: [viewingKey]) } catch { - XCTFail("Failed to init accounts table error: \(String(describing: rustBackend.getLastError()))") + XCTFail("Failed to init accounts table error: \(error)") return } @@ -108,12 +103,10 @@ class TransactionEnhancementTests: XCTestCase { } _ = try await rustBackend.initBlocksTable( - dbData: processorConfig.dataDb, height: Int32(birthday.height), hash: birthday.hash, time: birthday.time, - saplingTree: birthday.saplingTree, - networkType: network.networkType + saplingTree: birthday.saplingTree ) let service = DarksideWalletService() @@ -136,7 +129,7 @@ class TransactionEnhancementTests: XCTestCase { processor = CompactBlockProcessor( service: service, storage: storage, - backend: rustBackend, + rustBackend: rustBackend, config: processorConfig, metrics: SDKMetrics(), logger: logger @@ -163,6 +156,7 @@ class TransactionEnhancementTests: XCTestCase { processor = nil darksideWalletService = nil downloader = nil + testTempDirectory = nil } private func startProcessing() async throws { diff --git a/Tests/NetworkTests/BlockScanTests.swift b/Tests/NetworkTests/BlockScanTests.swift index 57fe64f6..75323999 100644 --- a/Tests/NetworkTests/BlockScanTests.swift +++ b/Tests/NetworkTests/BlockScanTests.swift @@ -15,8 +15,6 @@ import SQLite class BlockScanTests: XCTestCase { var cancelables: [AnyCancellable] = [] - let rustWelding = ZcashRustBackend.self - var dataDbURL: URL! var spendParamsURL: URL! var outputParamsURL: URL! @@ -27,25 +25,30 @@ class BlockScanTests: XCTestCase { with: 1386000, network: ZcashNetworkBuilder.network(for: .testnet) ) + + var rustBackend: ZcashRustBackendWelding! var network = ZcashNetworkBuilder.network(for: .testnet) var blockRepository: BlockRepository! - - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) + var testTempDirectory: URL! let testFileManager = FileManager() override func setUpWithError() throws { // Put setup code here. This method is called before the invocation of each test method in the class. try super.setUpWithError() - self.dataDbURL = try! __dataDbURL() - self.spendParamsURL = try! __spendParamsURL() - self.outputParamsURL = try! __outputParamsURL() + dataDbURL = try! __dataDbURL() + spendParamsURL = try! __spendParamsURL() + outputParamsURL = try! __outputParamsURL() + testTempDirectory = Environment.uniqueTestTempDirectory - try self.testFileManager.createDirectory(at: self.testTempDirectory, withIntermediateDirectories: false) + try self.testFileManager.createDirectory(at: testTempDirectory, withIntermediateDirectories: false) + + rustBackend = ZcashRustBackend.makeForTests( + dbData: dataDbURL, + fsBlockDbRoot: testTempDirectory, + networkType: network.networkType + ) deleteDBs() } @@ -63,25 +66,23 @@ class BlockScanTests: XCTestCase { try? testFileManager.removeItem(at: testTempDirectory) cancelables = [] blockRepository = nil + testTempDirectory = nil } func testSingleDownloadAndScan() async throws { logger = OSLogger(logLevel: .debug) - _ = try await rustWelding.initDataDb(dbData: dataDbURL, seed: nil, networkType: network.networkType) + _ = try await rustBackend.initDataDb(seed: nil) let endpoint = LightWalletEndpoint(address: "lightwalletd.testnet.electriccoin.co", port: 9067) let service = LightWalletServiceFactory(endpoint: endpoint).make() let blockCount = 100 let range = network.constants.saplingActivationHeight ... network.constants.saplingActivationHeight + blockCount - let fsDbRootURL = self.testTempDirectory - - let rustBackend = ZcashRustBackend.self let fsBlockRepository = FSCompactBlockRepository( - fsBlockDbRoot: fsDbRootURL, + fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( - fsBlockDbRoot: fsDbRootURL, + fsBlockDbRoot: testTempDirectory, rustBackend: rustBackend, logger: logger ), @@ -94,7 +95,7 @@ class BlockScanTests: XCTestCase { let processorConfig = CompactBlockProcessor.Configuration( alias: .default, - fsBlockCacheRoot: fsDbRootURL, + fsBlockCacheRoot: testTempDirectory, dataDb: dataDbURL, spendParamsURL: spendParamsURL, outputParamsURL: outputParamsURL, @@ -106,7 +107,7 @@ class BlockScanTests: XCTestCase { let compactBlockProcessor = CompactBlockProcessor( service: service, storage: fsBlockRepository, - backend: rustBackend, + rustBackend: rustBackend, config: processorConfig, metrics: SDKMetrics(), logger: logger @@ -139,45 +140,36 @@ class BlockScanTests: XCTestCase { let metrics = SDKMetrics() metrics.enableMetrics() - guard try await self.rustWelding.initDataDb(dbData: dataDbURL, seed: nil, networkType: network.networkType) == .success else { + guard try await rustBackend.initDataDb(seed: nil) == .success else { XCTFail("Seed should not be required for this test") return } let derivationTool = DerivationTool(networkType: .testnet) - let ufvk = try derivationTool - .deriveUnifiedSpendingKey(seed: Array(seed.utf8), accountIndex: 0) - .map { try derivationTool.deriveUnifiedFullViewingKey(from: $0) } + let spendingKey = try await derivationTool.deriveUnifiedSpendingKey(seed: Array(seed.utf8), accountIndex: 0) + let viewingKey = try await derivationTool.deriveUnifiedFullViewingKey(from: spendingKey) do { - try await self.rustWelding.initAccountsTable( - dbData: self.dataDbURL, - ufvks: [ufvk], - networkType: network.networkType - ) + try await rustBackend.initAccountsTable(ufvks: [viewingKey]) } catch { - XCTFail("failed to init account table. error: \(self.rustWelding.getLastError() ?? "no error found")") + XCTFail("failed to init account table. error: \(error)") return } - try await self.rustWelding.initBlocksTable( - dbData: dataDbURL, + try await rustBackend.initBlocksTable( height: Int32(walletBirthDay.height), hash: walletBirthDay.hash, time: walletBirthDay.time, - saplingTree: walletBirthDay.saplingTree, - networkType: network.networkType + saplingTree: walletBirthDay.saplingTree ) let service = LightWalletServiceFactory(endpoint: LightWalletEndpointBuilder.eccTestnet).make() - let fsDbRootURL = self.testTempDirectory - let fsBlockRepository = FSCompactBlockRepository( - fsBlockDbRoot: fsDbRootURL, + fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( - fsBlockDbRoot: fsDbRootURL, - rustBackend: rustWelding, + fsBlockDbRoot: testTempDirectory, + rustBackend: rustBackend, logger: logger ), blockDescriptor: ZcashCompactBlockDescriptor.live, @@ -189,7 +181,7 @@ class BlockScanTests: XCTestCase { var processorConfig = CompactBlockProcessor.Configuration( alias: .default, - fsBlockCacheRoot: fsDbRootURL, + fsBlockCacheRoot: testTempDirectory, dataDb: dataDbURL, spendParamsURL: spendParamsURL, outputParamsURL: outputParamsURL, @@ -202,7 +194,7 @@ class BlockScanTests: XCTestCase { let compactBlockProcessor = CompactBlockProcessor( service: service, storage: fsBlockRepository, - backend: rustWelding, + rustBackend: rustBackend, config: processorConfig, metrics: metrics, logger: logger diff --git a/Tests/NetworkTests/BlockStreamingTest.swift b/Tests/NetworkTests/BlockStreamingTest.swift index 11d1c624..3de5d21a 100644 --- a/Tests/NetworkTests/BlockStreamingTest.swift +++ b/Tests/NetworkTests/BlockStreamingTest.swift @@ -10,23 +10,24 @@ import XCTest @testable import ZcashLightClientKit class BlockStreamingTest: XCTestCase { - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) - let testFileManager = FileManager() + var rustBackend: ZcashRustBackendWelding! + var testTempDirectory: URL! override func setUpWithError() throws { try super.setUpWithError() - try self.testFileManager.createDirectory(at: self.testTempDirectory, withIntermediateDirectories: false) + testTempDirectory = Environment.uniqueTestTempDirectory + try self.testFileManager.createDirectory(at: testTempDirectory, withIntermediateDirectories: false) + rustBackend = ZcashRustBackend.makeForTests(fsBlockDbRoot: testTempDirectory, networkType: .testnet) logger = OSLogger(logLevel: .debug) } override func tearDownWithError() throws { try super.tearDownWithError() + rustBackend = nil try? FileManager.default.removeItem(at: __dataDbURL()) try? testFileManager.removeItem(at: testTempDirectory) + testTempDirectory = nil } func testStream() async throws { @@ -68,13 +69,11 @@ class BlockStreamingTest: XCTestCase { ) let service = LightWalletServiceFactory(endpoint: endpoint).make() - let realRustBackend = ZcashRustBackend.self - let storage = FSCompactBlockRepository( fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: realRustBackend, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -94,7 +93,7 @@ class BlockStreamingTest: XCTestCase { let compactBlockProcessor = CompactBlockProcessor( service: service, storage: storage, - backend: realRustBackend, + rustBackend: rustBackend, config: processorConfig, metrics: SDKMetrics(), logger: logger @@ -132,13 +131,11 @@ class BlockStreamingTest: XCTestCase { ) let service = LightWalletServiceFactory(endpoint: endpoint).make() - let realRustBackend = ZcashRustBackend.self - let storage = FSCompactBlockRepository( fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: realRustBackend, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -160,7 +157,7 @@ class BlockStreamingTest: XCTestCase { let compactBlockProcessor = CompactBlockProcessor( service: service, storage: storage, - backend: realRustBackend, + rustBackend: rustBackend, config: processorConfig, metrics: SDKMetrics(), logger: logger diff --git a/Tests/NetworkTests/CompactBlockProcessorTests.swift b/Tests/NetworkTests/CompactBlockProcessorTests.swift index 16c565ff..69cb0b0b 100644 --- a/Tests/NetworkTests/CompactBlockProcessorTests.swift +++ b/Tests/NetworkTests/CompactBlockProcessorTests.swift @@ -12,9 +12,30 @@ import XCTest @testable import ZcashLightClientKit class CompactBlockProcessorTests: XCTestCase { - lazy var processorConfig = { + var processorConfig: CompactBlockProcessor.Configuration! + var cancellables: [AnyCancellable] = [] + var processorEventHandler: CompactBlockProcessorEventHandler! = CompactBlockProcessorEventHandler() + var rustBackend: ZcashRustBackendWelding! + var processor: CompactBlockProcessor! + var syncStartedExpect: XCTestExpectation! + var updatedNotificationExpectation: XCTestExpectation! + var stopNotificationExpectation: XCTestExpectation! + var finishedNotificationExpectation: XCTestExpectation! + let network = ZcashNetworkBuilder.network(for: .testnet) + let mockLatestHeight = ZcashNetworkBuilder.network(for: .testnet).constants.saplingActivationHeight + 2000 + + let testFileManager = FileManager() + var testTempDirectory: URL! + + override func setUp() async throws { + try await super.setUp() + logger = OSLogger(logLevel: .debug) + testTempDirectory = Environment.uniqueTestTempDirectory + + try self.testFileManager.createDirectory(at: testTempDirectory, withIntermediateDirectories: false) + let pathProvider = DefaultResourceProvider(network: network) - return CompactBlockProcessor.Configuration( + processorConfig = CompactBlockProcessor.Configuration( alias: .default, fsBlockCacheRoot: testTempDirectory, dataDb: pathProvider.dataDbURL, @@ -24,28 +45,6 @@ class CompactBlockProcessorTests: XCTestCase { walletBirthdayProvider: { ZcashNetworkBuilder.network(for: .testnet).constants.saplingActivationHeight }, network: ZcashNetworkBuilder.network(for: .testnet) ) - }() - - var cancellables: [AnyCancellable] = [] - var processorEventHandler: CompactBlockProcessorEventHandler! = CompactBlockProcessorEventHandler() - var processor: CompactBlockProcessor! - var syncStartedExpect: XCTestExpectation! - var updatedNotificationExpectation: XCTestExpectation! - var stopNotificationExpectation: XCTestExpectation! - var finishedNotificationExpectation: XCTestExpectation! - let network = ZcashNetworkBuilder.network(for: .testnet) - let mockLatestHeight = ZcashNetworkBuilder.network(for: .testnet).constants.saplingActivationHeight + 2000 - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) - - let testFileManager = FileManager() - - override func setUp() async throws { - try await super.setUp() - logger = OSLogger(logLevel: .debug) - try self.testFileManager.createDirectory(at: self.testTempDirectory, withIntermediateDirectories: false) await InternalSyncProgress( alias: .default, @@ -58,7 +57,14 @@ class CompactBlockProcessorTests: XCTestCase { latestBlockHeight: mockLatestHeight, service: liveService ) - let branchID = try ZcashRustBackend.consensusBranchIdFor(height: Int32(mockLatestHeight), networkType: network.networkType) + + rustBackend = ZcashRustBackend.makeForTests( + dbData: processorConfig.dataDb, + fsBlockDbRoot: processorConfig.fsBlockCacheRoot, + networkType: network.networkType + ) + + let branchID = try rustBackend.consensusBranchIdFor(height: Int32(mockLatestHeight)) service.mockLightDInfo = LightdInfo.with({ info in info.blockHeight = UInt64(mockLatestHeight) info.branch = "asdf" @@ -70,13 +76,11 @@ class CompactBlockProcessorTests: XCTestCase { info.saplingActivationHeight = UInt64(network.constants.saplingActivationHeight) }) - let realRustBackend = ZcashRustBackend.self - let storage = FSCompactBlockRepository( fsBlockDbRoot: processorConfig.fsBlockCacheRoot, metadataStore: FSMetadataStore.live( fsBlockDbRoot: processorConfig.fsBlockCacheRoot, - rustBackend: realRustBackend, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -89,13 +93,13 @@ class CompactBlockProcessorTests: XCTestCase { processor = CompactBlockProcessor( service: service, storage: storage, - backend: realRustBackend, + rustBackend: rustBackend, config: processorConfig, metrics: SDKMetrics(), logger: logger ) - let dbInit = try await realRustBackend.initDataDb(dbData: processorConfig.dataDb, seed: nil, networkType: .testnet) + let dbInit = try await rustBackend.initDataDb(seed: nil) guard case .success = dbInit else { XCTFail("Failed to initDataDb. Expected `.success` got: \(dbInit)") @@ -125,6 +129,8 @@ class CompactBlockProcessorTests: XCTestCase { cancellables = [] processor = nil processorEventHandler = nil + rustBackend = nil + testTempDirectory = nil } func processorFailed(event: CompactBlockProcessor.Event) { diff --git a/Tests/NetworkTests/CompactBlockReorgTests.swift b/Tests/NetworkTests/CompactBlockReorgTests.swift index 6bb137f4..ce428546 100644 --- a/Tests/NetworkTests/CompactBlockReorgTests.swift +++ b/Tests/NetworkTests/CompactBlockReorgTests.swift @@ -12,9 +12,31 @@ import XCTest @testable import ZcashLightClientKit class CompactBlockReorgTests: XCTestCase { - lazy var processorConfig = { + var processorConfig: CompactBlockProcessor.Configuration! + let testFileManager = FileManager() + var cancellables: [AnyCancellable] = [] + var processorEventHandler: CompactBlockProcessorEventHandler! = CompactBlockProcessorEventHandler() + var rustBackend: ZcashRustBackendWelding! + var rustBackendMockHelper: RustBackendMockHelper! + var processor: CompactBlockProcessor! + var syncStartedExpect: XCTestExpectation! + var updatedNotificationExpectation: XCTestExpectation! + var stopNotificationExpectation: XCTestExpectation! + var finishedNotificationExpectation: XCTestExpectation! + var reorgNotificationExpectation: XCTestExpectation! + let network = ZcashNetworkBuilder.network(for: .testnet) + let mockLatestHeight = ZcashNetworkBuilder.network(for: .testnet).constants.saplingActivationHeight + 2000 + var testTempDirectory: URL! + + override func setUp() async throws { + try await super.setUp() + logger = OSLogger(logLevel: .debug) + testTempDirectory = Environment.uniqueTestTempDirectory + + try self.testFileManager.createDirectory(at: testTempDirectory, withIntermediateDirectories: false) + let pathProvider = DefaultResourceProvider(network: network) - return CompactBlockProcessor.Configuration( + processorConfig = CompactBlockProcessor.Configuration( alias: .default, fsBlockCacheRoot: testTempDirectory, dataDb: pathProvider.dataDbURL, @@ -24,30 +46,6 @@ class CompactBlockReorgTests: XCTestCase { walletBirthdayProvider: { ZcashNetworkBuilder.network(for: .testnet).constants.saplingActivationHeight }, network: ZcashNetworkBuilder.network(for: .testnet) ) - }() - - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) - - let testFileManager = FileManager() - var cancellables: [AnyCancellable] = [] - var processorEventHandler: CompactBlockProcessorEventHandler! = CompactBlockProcessorEventHandler() - - var processor: CompactBlockProcessor! - var syncStartedExpect: XCTestExpectation! - var updatedNotificationExpectation: XCTestExpectation! - var stopNotificationExpectation: XCTestExpectation! - var finishedNotificationExpectation: XCTestExpectation! - var reorgNotificationExpectation: XCTestExpectation! - let network = ZcashNetworkBuilder.network(for: .testnet) - let mockLatestHeight = ZcashNetworkBuilder.network(for: .testnet).constants.saplingActivationHeight + 2000 - - override func setUp() async throws { - try await super.setUp() - logger = OSLogger(logLevel: .debug) - try self.testFileManager.createDirectory(at: self.testTempDirectory, withIntermediateDirectories: false) await InternalSyncProgress( alias: .default, @@ -60,8 +58,14 @@ class CompactBlockReorgTests: XCTestCase { latestBlockHeight: mockLatestHeight, service: liveService ) - - let branchID = try ZcashRustBackend.consensusBranchIdFor(height: Int32(mockLatestHeight), networkType: network.networkType) + + rustBackend = ZcashRustBackend.makeForTests( + dbData: processorConfig.dataDb, + fsBlockDbRoot: processorConfig.fsBlockCacheRoot, + networkType: network.networkType + ) + + let branchID = try rustBackend.consensusBranchIdFor(height: Int32(mockLatestHeight)) service.mockLightDInfo = LightdInfo.with { info in info.blockHeight = UInt64(mockLatestHeight) info.branch = "asdf" @@ -73,13 +77,11 @@ class CompactBlockReorgTests: XCTestCase { info.saplingActivationHeight = UInt64(network.constants.saplingActivationHeight) } - let realRustBackend = ZcashRustBackend.self - let realCache = FSCompactBlockRepository( fsBlockDbRoot: processorConfig.fsBlockCacheRoot, metadataStore: FSMetadataStore.live( fsBlockDbRoot: processorConfig.fsBlockCacheRoot, - rustBackend: realRustBackend, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -89,21 +91,23 @@ class CompactBlockReorgTests: XCTestCase { try await realCache.create() - let initResult = try await realRustBackend.initDataDb(dbData: processorConfig.dataDb, seed: nil, networkType: .testnet) + let initResult = try await rustBackend.initDataDb(seed: nil) guard case .success = initResult else { XCTFail("initDataDb failed. Expected Success but got .seedRequired") return } - let mockBackend = MockRustBackend.self - mockBackend.mockValidateCombinedChainFailAfterAttempts = 3 - mockBackend.mockValidateCombinedChainKeepFailing = false - mockBackend.mockValidateCombinedChainFailureHeight = self.network.constants.saplingActivationHeight + 320 - + rustBackendMockHelper = await RustBackendMockHelper( + rustBackend: rustBackend, + mockValidateCombinedChainFailAfterAttempts: 3, + mockValidateCombinedChainKeepFailing: false, + mockValidateCombinedChainFailureError: .invalidChain(upperBound: Int32(network.constants.saplingActivationHeight + 320)) + ) + processor = CompactBlockProcessor( service: service, storage: realCache, - backend: mockBackend, + rustBackend: rustBackendMockHelper.rustBackendMock, config: processorConfig, metrics: SDKMetrics(), logger: logger @@ -134,6 +138,8 @@ class CompactBlockReorgTests: XCTestCase { cancellables = [] processorEventHandler = nil processor = nil + rustBackend = nil + rustBackendMockHelper = nil } func processorHandledReorg(event: CompactBlockProcessor.Event) { diff --git a/Tests/NetworkTests/DownloadTests.swift b/Tests/NetworkTests/DownloadTests.swift index 1062271d..d0093e0c 100644 --- a/Tests/NetworkTests/DownloadTests.swift +++ b/Tests/NetworkTests/DownloadTests.swift @@ -12,36 +12,31 @@ import SQLite @testable import ZcashLightClientKit class DownloadTests: XCTestCase { - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) - let testFileManager = FileManager() - var network = ZcashNetworkBuilder.network(for: .testnet) + var testTempDirectory: URL! override func setUpWithError() throws { try super.setUpWithError() - - try self.testFileManager.createDirectory(at: self.testTempDirectory, withIntermediateDirectories: false) + testTempDirectory = Environment.uniqueTestTempDirectory + try self.testFileManager.createDirectory(at: testTempDirectory, withIntermediateDirectories: false) } override func tearDownWithError() throws { try super.tearDownWithError() try? testFileManager.removeItem(at: testTempDirectory) + testTempDirectory = nil } func testSingleDownload() async throws { let service = LightWalletServiceFactory(endpoint: LightWalletEndpointBuilder.eccTestnet).make() - - let realRustBackend = ZcashRustBackend.self + let rustBackend = ZcashRustBackend.makeForTests(fsBlockDbRoot: testTempDirectory, networkType: network.networkType) let storage = FSCompactBlockRepository( fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: realRustBackend, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -63,7 +58,7 @@ class DownloadTests: XCTestCase { let compactBlockProcessor = CompactBlockProcessor( service: service, storage: storage, - backend: realRustBackend, + rustBackend: rustBackend, config: processorConfig, metrics: SDKMetrics(), logger: logger diff --git a/Tests/OfflineTests/BlockBatchValidationTests.swift b/Tests/OfflineTests/BlockBatchValidationTests.swift index afd90da6..36eca434 100644 --- a/Tests/OfflineTests/BlockBatchValidationTests.swift +++ b/Tests/OfflineTests/BlockBatchValidationTests.swift @@ -10,21 +10,22 @@ import XCTest @testable import ZcashLightClientKit class BlockBatchValidationTests: XCTestCase { - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) - let testFileManager = FileManager() + var rustBackend: ZcashRustBackendWelding! + var testTempDirectory: URL! override func setUpWithError() throws { try super.setUpWithError() - try self.testFileManager.createDirectory(at: self.testTempDirectory, withIntermediateDirectories: false) + testTempDirectory = Environment.uniqueTestTempDirectory + try self.testFileManager.createDirectory(at: testTempDirectory, withIntermediateDirectories: false) + rustBackend = ZcashRustBackend.makeForTests(fsBlockDbRoot: testTempDirectory, networkType: .testnet) } override func tearDownWithError() throws { try super.tearDownWithError() try? testFileManager.removeItem(at: testTempDirectory) + rustBackend = nil + testTempDirectory = nil } func testBranchIdFailure() async throws { @@ -34,13 +35,11 @@ class BlockBatchValidationTests: XCTestCase { service: LightWalletServiceFactory(endpoint: LightWalletEndpointBuilder.default).make() ) - let realRustBackend = ZcashRustBackend.self - let storage = FSCompactBlockRepository( fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: realRustBackend, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -76,14 +75,13 @@ class BlockBatchValidationTests: XCTestCase { info.consensusBranchID = "d34db33f" info.saplingActivationHeight = UInt64(network.constants.saplingActivationHeight) service.mockLightDInfo = info - - let mockRust = MockRustBackend.self - mockRust.consensusBranchID = Int32(0xd34d) + + let mockBackend = await RustBackendMockHelper(rustBackend: rustBackend, consensusBranchID: Int32(0xd34d)) let compactBlockProcessor = CompactBlockProcessor( service: service, storage: storage, - backend: mockRust, + rustBackend: mockBackend.rustBackendMock, config: config, metrics: SDKMetrics(), logger: logger @@ -109,13 +107,11 @@ class BlockBatchValidationTests: XCTestCase { service: LightWalletServiceFactory(endpoint: LightWalletEndpointBuilder.default).make() ) - let realRustBackend = ZcashRustBackend.self - let storage = FSCompactBlockRepository( fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: realRustBackend, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -151,14 +147,13 @@ class BlockBatchValidationTests: XCTestCase { info.saplingActivationHeight = UInt64(network.constants.saplingActivationHeight) service.mockLightDInfo = info - - let mockRust = MockRustBackend.self - mockRust.consensusBranchID = 0xd34db4d + + let mockBackend = await RustBackendMockHelper(rustBackend: rustBackend, consensusBranchID: 0xd34db4d) let compactBlockProcessor = CompactBlockProcessor( service: service, storage: storage, - backend: mockRust, + rustBackend: mockBackend.rustBackendMock, config: config, metrics: SDKMetrics(), logger: logger @@ -184,13 +179,11 @@ class BlockBatchValidationTests: XCTestCase { service: LightWalletServiceFactory(endpoint: LightWalletEndpointBuilder.default).make() ) - let realRustBackend = ZcashRustBackend.self - let storage = FSCompactBlockRepository( fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: realRustBackend, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -227,13 +220,12 @@ class BlockBatchValidationTests: XCTestCase { service.mockLightDInfo = info - let mockRust = MockRustBackend.self - mockRust.consensusBranchID = 0xd34db4d + let mockBackend = await RustBackendMockHelper(rustBackend: rustBackend, consensusBranchID: 0xd34db4d) let compactBlockProcessor = CompactBlockProcessor( service: service, storage: storage, - backend: mockRust, + rustBackend: mockBackend.rustBackendMock, config: config, metrics: SDKMetrics(), logger: logger @@ -259,13 +251,11 @@ class BlockBatchValidationTests: XCTestCase { service: LightWalletServiceFactory(endpoint: LightWalletEndpointBuilder.default).make() ) - let realRustBackend = ZcashRustBackend.self - let storage = FSCompactBlockRepository( fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: realRustBackend, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -303,13 +293,12 @@ class BlockBatchValidationTests: XCTestCase { service.mockLightDInfo = info - let mockRust = MockRustBackend.self - mockRust.consensusBranchID = 0xd34db4d + let mockBackend = await RustBackendMockHelper(rustBackend: rustBackend, consensusBranchID: 0xd34db4d) let compactBlockProcessor = CompactBlockProcessor( service: service, storage: storage, - backend: mockRust, + rustBackend: mockBackend.rustBackendMock, config: config, metrics: SDKMetrics(), logger: logger @@ -381,9 +370,8 @@ class BlockBatchValidationTests: XCTestCase { info.saplingActivationHeight = UInt64(network.constants.saplingActivationHeight) service.mockLightDInfo = info - - let mockRust = MockRustBackend.self - mockRust.consensusBranchID = 0xd34db4d + + let mockBackend = await RustBackendMockHelper(rustBackend: rustBackend, consensusBranchID: 0xd34db4d) var nextBatch: CompactBlockProcessor.NextState? do { @@ -392,7 +380,7 @@ class BlockBatchValidationTests: XCTestCase { downloaderService: downloaderService, transactionRepository: transactionRepository, config: config, - rustBackend: mockRust, + rustBackend: mockBackend.rustBackendMock, internalSyncProgress: InternalSyncProgress( alias: .default, storage: InternalSyncProgressMemoryStorage(), @@ -479,8 +467,7 @@ class BlockBatchValidationTests: XCTestCase { service.mockLightDInfo = info - let mockRust = MockRustBackend.self - mockRust.consensusBranchID = 0xd34db4d + let mockBackend = await RustBackendMockHelper(rustBackend: rustBackend, consensusBranchID: 0xd34db4d) var nextBatch: CompactBlockProcessor.NextState? do { @@ -489,7 +476,7 @@ class BlockBatchValidationTests: XCTestCase { downloaderService: downloaderService, transactionRepository: transactionRepository, config: config, - rustBackend: mockRust, + rustBackend: mockBackend.rustBackendMock, internalSyncProgress: InternalSyncProgress( alias: .default, storage: InternalSyncProgressMemoryStorage(), @@ -573,8 +560,7 @@ class BlockBatchValidationTests: XCTestCase { service.mockLightDInfo = info - let mockRust = MockRustBackend.self - mockRust.consensusBranchID = 0xd34db4d + let mockBackend = await RustBackendMockHelper(rustBackend: rustBackend, consensusBranchID: 0xd34db4d) var nextBatch: CompactBlockProcessor.NextState? do { @@ -583,7 +569,7 @@ class BlockBatchValidationTests: XCTestCase { downloaderService: downloaderService, transactionRepository: transactionRepository, config: config, - rustBackend: mockRust, + rustBackend: mockBackend.rustBackendMock, internalSyncProgress: internalSyncProgress ) diff --git a/Tests/OfflineTests/ClosureSynchronizerOfflineTests.swift b/Tests/OfflineTests/ClosureSynchronizerOfflineTests.swift index 46f28f65..5017ee57 100644 --- a/Tests/OfflineTests/ClosureSynchronizerOfflineTests.swift +++ b/Tests/OfflineTests/ClosureSynchronizerOfflineTests.swift @@ -14,7 +14,7 @@ import XCTest extension String: Error { } class ClosureSynchronizerOfflineTests: XCTestCase { - var data: AlternativeSynchronizerAPITestsData! + var data: TestsData! var cancellables: [AnyCancellable] = [] var synchronizerMock: SynchronizerMock! @@ -22,7 +22,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { override func setUpWithError() throws { try super.setUpWithError() - data = AlternativeSynchronizerAPITestsData() + data = TestsData(networkType: .testnet) synchronizerMock = SynchronizerMock() synchronizer = ClosureSDKSynchronizer(synchronizer: synchronizerMock) cancellables = [] @@ -114,17 +114,18 @@ class ClosureSynchronizerOfflineTests: XCTestCase { XCTAssertEqual(synchronizer.connectionState, .reconnecting) } - func testPrepareSucceed() throws { - synchronizerMock.prepareWithSeedViewingKeysWalletBirthdayClosure = { receivedSeed, receivedViewingKeys, receivedWalletBirthday in + func testPrepareSucceed() async throws { + let mockedViewingKey = await data.viewingKey + synchronizerMock.prepareWithViewingKeysWalletBirthdayClosure = { receivedSeed, receivedViewingKeys, receivedWalletBirthday in XCTAssertEqual(receivedSeed, self.data.seed) - XCTAssertEqual(receivedViewingKeys, [self.data.viewingKey]) + XCTAssertEqual(receivedViewingKeys, [mockedViewingKey]) XCTAssertEqual(receivedWalletBirthday, self.data.birthday) return .success } let expectation = XCTestExpectation() - synchronizer.prepare(with: data.seed, viewingKeys: [data.viewingKey], walletBirthday: data.birthday) { result in + synchronizer.prepare(with: data.seed, viewingKeys: [mockedViewingKey], walletBirthday: data.birthday) { result in switch result { case let .success(status): XCTAssertEqual(status, .success) @@ -137,14 +138,15 @@ class ClosureSynchronizerOfflineTests: XCTestCase { wait(for: [expectation], timeout: 0.5) } - func testPrepareThrowsError() throws { - synchronizerMock.prepareWithSeedViewingKeysWalletBirthdayClosure = { _, _, _ in + func testPrepareThrowsError() async throws { + let mockedViewingKey = await data.viewingKey + synchronizerMock.prepareWithViewingKeysWalletBirthdayClosure = { _, _, _ in throw "Some error" } let expectation = XCTestExpectation() - synchronizer.prepare(with: data.seed, viewingKeys: [data.viewingKey], walletBirthday: data.birthday) { result in + synchronizer.prepare(with: data.seed, viewingKeys: [mockedViewingKey], walletBirthday: data.birthday) { result in switch result { case .success: XCTFail("Error should be thrown.") @@ -324,14 +326,15 @@ class ClosureSynchronizerOfflineTests: XCTestCase { wait(for: [expectation], timeout: 0.5) } - func testSendToAddressSucceed() throws { + func testSendToAddressSucceed() async throws { let amount = Zatoshi(100) let recipient: Recipient = .transparent(data.transparentAddress) let memo: Memo = .text(try MemoText("Some message")) + let mockedSpendingKey = await data.spendingKey synchronizerMock .sendToAddressSpendingKeyZatoshiToAddressMemoClosure = { receivedSpendingKey, receivedZatoshi, receivedToAddress, receivedMemo in - XCTAssertEqual(receivedSpendingKey, self.data.spendingKey) + XCTAssertEqual(receivedSpendingKey, mockedSpendingKey) XCTAssertEqual(receivedZatoshi, amount) XCTAssertEqual(receivedToAddress, recipient) XCTAssertEqual(receivedMemo, memo) @@ -340,7 +343,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { let expectation = XCTestExpectation() - synchronizer.sendToAddress(spendingKey: data.spendingKey, zatoshi: amount, toAddress: recipient, memo: memo) { result in + synchronizer.sendToAddress(spendingKey: mockedSpendingKey, zatoshi: amount, toAddress: recipient, memo: memo) { result in switch result { case let .success(receivedEntity): XCTAssertEqual(receivedEntity.recipient, self.data.pendingTransactionEntity.recipient) @@ -353,10 +356,11 @@ class ClosureSynchronizerOfflineTests: XCTestCase { wait(for: [expectation], timeout: 0.5) } - func testSendToAddressThrowsError() throws { + func testSendToAddressThrowsError() async throws { let amount = Zatoshi(100) let recipient: Recipient = .transparent(data.transparentAddress) let memo: Memo = .text(try MemoText("Some message")) + let mockedSpendingKey = await data.spendingKey synchronizerMock.sendToAddressSpendingKeyZatoshiToAddressMemoClosure = { _, _, _, _ in throw "Some error" @@ -364,7 +368,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { let expectation = XCTestExpectation() - synchronizer.sendToAddress(spendingKey: data.spendingKey, zatoshi: amount, toAddress: recipient, memo: memo) { result in + synchronizer.sendToAddress(spendingKey: mockedSpendingKey, zatoshi: amount, toAddress: recipient, memo: memo) { result in switch result { case .success: XCTFail("Error should be thrown.") @@ -376,12 +380,13 @@ class ClosureSynchronizerOfflineTests: XCTestCase { wait(for: [expectation], timeout: 0.5) } - func testShieldFundsSucceed() throws { + func testShieldFundsSucceed() async throws { let memo: Memo = .text(try MemoText("Some message")) let shieldingThreshold = Zatoshi(1) + let mockedSpendingKey = await data.spendingKey synchronizerMock.shieldFundsSpendingKeyMemoShieldingThresholdClosure = { receivedSpendingKey, receivedMemo, receivedShieldingThreshold in - XCTAssertEqual(receivedSpendingKey, self.data.spendingKey) + XCTAssertEqual(receivedSpendingKey, mockedSpendingKey) XCTAssertEqual(receivedMemo, memo) XCTAssertEqual(receivedShieldingThreshold, shieldingThreshold) return self.data.pendingTransactionEntity @@ -389,7 +394,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { let expectation = XCTestExpectation() - synchronizer.shieldFunds(spendingKey: data.spendingKey, memo: memo, shieldingThreshold: shieldingThreshold) { result in + synchronizer.shieldFunds(spendingKey: mockedSpendingKey, memo: memo, shieldingThreshold: shieldingThreshold) { result in switch result { case let .success(receivedEntity): XCTAssertEqual(receivedEntity.recipient, self.data.pendingTransactionEntity.recipient) @@ -402,9 +407,10 @@ class ClosureSynchronizerOfflineTests: XCTestCase { wait(for: [expectation], timeout: 0.5) } - func testShieldFundsThrowsError() throws { + func testShieldFundsThrowsError() async throws { let memo: Memo = .text(try MemoText("Some message")) let shieldingThreshold = Zatoshi(1) + let mockedSpendingKey = await data.spendingKey synchronizerMock.shieldFundsSpendingKeyMemoShieldingThresholdClosure = { _, _, _ in throw "Some error" @@ -412,7 +418,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { let expectation = XCTestExpectation() - synchronizer.shieldFunds(spendingKey: data.spendingKey, memo: memo, shieldingThreshold: shieldingThreshold) { result in + synchronizer.shieldFunds(spendingKey: mockedSpendingKey, memo: memo, shieldingThreshold: shieldingThreshold) { result in switch result { case .success: XCTFail("Error should be thrown.") @@ -499,7 +505,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { func testGetMemosForClearedTransactionSucceed() throws { let memo: Memo = .text(try MemoText("Some message")) - synchronizerMock.getMemosForTransactionClosure = { receivedTransaction in + synchronizerMock.getMemosForClearedTransactionClosure = { receivedTransaction in XCTAssertEqual(receivedTransaction.id, self.data.clearedTransaction.id) return [memo] } @@ -521,7 +527,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { } func testGetMemosForClearedTransactionThrowsError() { - synchronizerMock.getMemosForTransactionClosure = { _ in + synchronizerMock.getMemosForClearedTransactionClosure = { _ in throw "Some error" } @@ -664,7 +670,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { } func testAllConfirmedTransactionsSucceed() throws { - synchronizerMock.allConfirmedTransactionsFromTransactionClosure = { receivedTransaction, limit in + synchronizerMock.allConfirmedTransactionsFromLimitClosure = { receivedTransaction, limit in XCTAssertEqual(receivedTransaction.id, self.data.clearedTransaction.id) XCTAssertEqual(limit, 3) return [self.data.clearedTransaction] @@ -687,7 +693,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { } func testAllConfirmedTransactionsThrowsError() throws { - synchronizerMock.allConfirmedTransactionsFromTransactionClosure = { _, _ in + synchronizerMock.allConfirmedTransactionsFromLimitClosure = { _, _ in throw "Some error" } @@ -747,7 +753,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { let skippedEntity = UnspentTransactionOutputEntityMock(address: "addr2", txid: Data(), index: 1, script: Data(), valueZat: 2, height: 3) let refreshedUTXO = (inserted: [insertedEntity], skipped: [skippedEntity]) - synchronizerMock.refreshUTXOsAddressFromHeightClosure = { receivedAddress, receivedFromHeight in + synchronizerMock.refreshUTXOsAddressFromClosure = { receivedAddress, receivedFromHeight in XCTAssertEqual(receivedAddress, self.data.transparentAddress) XCTAssertEqual(receivedFromHeight, 121000) return refreshedUTXO @@ -770,7 +776,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { } func testRefreshUTXOsThrowsError() { - synchronizerMock.refreshUTXOsAddressFromHeightClosure = { _, _ in + synchronizerMock.refreshUTXOsAddressFromClosure = { _, _ in throw "Some error" } @@ -911,7 +917,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { } func testRewindSucceed() { - synchronizerMock.rewindPolicyClosure = { receivedPolicy in + synchronizerMock.rewindClosure = { receivedPolicy in if case .quick = receivedPolicy { } else { XCTFail("Unexpected policy \(receivedPolicy)") @@ -940,7 +946,7 @@ class ClosureSynchronizerOfflineTests: XCTestCase { } func testRewindThrowsError() { - synchronizerMock.rewindPolicyClosure = { _ in + synchronizerMock.rewindClosure = { _ in return Fail(error: "some error").eraseToAnyPublisher() } diff --git a/Tests/OfflineTests/CombineSynchronizerOfflineTests.swift b/Tests/OfflineTests/CombineSynchronizerOfflineTests.swift index faf8c7fd..0ade088e 100644 --- a/Tests/OfflineTests/CombineSynchronizerOfflineTests.swift +++ b/Tests/OfflineTests/CombineSynchronizerOfflineTests.swift @@ -12,7 +12,7 @@ import XCTest @testable import ZcashLightClientKit class CombineSynchronizerOfflineTests: XCTestCase { - var data: AlternativeSynchronizerAPITestsData! + var data: TestsData! var cancellables: [AnyCancellable] = [] var synchronizerMock: SynchronizerMock! @@ -20,7 +20,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { override func setUpWithError() throws { try super.setUpWithError() - data = AlternativeSynchronizerAPITestsData() + data = TestsData(networkType: .testnet) synchronizerMock = SynchronizerMock() synchronizer = CombineSDKSynchronizer(synchronizer: synchronizerMock) cancellables = [] @@ -112,17 +112,18 @@ class CombineSynchronizerOfflineTests: XCTestCase { XCTAssertEqual(synchronizer.connectionState, .reconnecting) } - func testPrepareSucceed() throws { - synchronizerMock.prepareWithSeedViewingKeysWalletBirthdayClosure = { receivedSeed, receivedViewingKeys, receivedWalletBirthday in + func testPrepareSucceed() async throws { + let mockedViewingKey = await self.data.viewingKey + synchronizerMock.prepareWithViewingKeysWalletBirthdayClosure = { receivedSeed, receivedViewingKeys, receivedWalletBirthday in XCTAssertEqual(receivedSeed, self.data.seed) - XCTAssertEqual(receivedViewingKeys, [self.data.viewingKey]) + XCTAssertEqual(receivedViewingKeys, [mockedViewingKey]) XCTAssertEqual(receivedWalletBirthday, self.data.birthday) return .success } let expectation = XCTestExpectation() - synchronizer.prepare(with: data.seed, viewingKeys: [data.viewingKey], walletBirthday: data.birthday) + synchronizer.prepare(with: data.seed, viewingKeys: [mockedViewingKey], walletBirthday: data.birthday) .sink( receiveCompletion: { result in switch result { @@ -141,14 +142,15 @@ class CombineSynchronizerOfflineTests: XCTestCase { wait(for: [expectation], timeout: 0.5) } - func testPrepareThrowsError() throws { - synchronizerMock.prepareWithSeedViewingKeysWalletBirthdayClosure = { _, _, _ in + func testPrepareThrowsError() async throws { + let mockedViewingKey = await self.data.viewingKey + synchronizerMock.prepareWithViewingKeysWalletBirthdayClosure = { _, _, _ in throw "Some error" } let expectation = XCTestExpectation() - synchronizer.prepare(with: data.seed, viewingKeys: [data.viewingKey], walletBirthday: data.birthday) + synchronizer.prepare(with: data.seed, viewingKeys: [mockedViewingKey], walletBirthday: data.birthday) .sink( receiveCompletion: { result in switch result { @@ -329,14 +331,15 @@ class CombineSynchronizerOfflineTests: XCTestCase { wait(for: [expectation], timeout: 0.5) } - func testSendToAddressSucceed() throws { + func testSendToAddressSucceed() async throws { let amount = Zatoshi(100) let recipient: Recipient = .transparent(data.transparentAddress) let memo: Memo = .text(try MemoText("Some message")) + let mockedSpendingKey = await data.spendingKey synchronizerMock .sendToAddressSpendingKeyZatoshiToAddressMemoClosure = { receivedSpendingKey, receivedZatoshi, receivedToAddress, receivedMemo in - XCTAssertEqual(receivedSpendingKey, self.data.spendingKey) + XCTAssertEqual(receivedSpendingKey, mockedSpendingKey) XCTAssertEqual(receivedZatoshi, amount) XCTAssertEqual(receivedToAddress, recipient) XCTAssertEqual(receivedMemo, memo) @@ -345,7 +348,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { let expectation = XCTestExpectation() - synchronizer.sendToAddress(spendingKey: data.spendingKey, zatoshi: amount, toAddress: recipient, memo: memo) + synchronizer.sendToAddress(spendingKey: mockedSpendingKey, zatoshi: amount, toAddress: recipient, memo: memo) .sink( receiveCompletion: { result in switch result { @@ -364,10 +367,11 @@ class CombineSynchronizerOfflineTests: XCTestCase { wait(for: [expectation], timeout: 0.5) } - func testSendToAddressThrowsError() throws { + func testSendToAddressThrowsError() async throws { let amount = Zatoshi(100) let recipient: Recipient = .transparent(data.transparentAddress) let memo: Memo = .text(try MemoText("Some message")) + let mockedSpendingKey = await data.spendingKey synchronizerMock.sendToAddressSpendingKeyZatoshiToAddressMemoClosure = { _, _, _, _ in throw "Some error" @@ -375,7 +379,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { let expectation = XCTestExpectation() - synchronizer.sendToAddress(spendingKey: data.spendingKey, zatoshi: amount, toAddress: recipient, memo: memo) + synchronizer.sendToAddress(spendingKey: mockedSpendingKey, zatoshi: amount, toAddress: recipient, memo: memo) .sink( receiveCompletion: { result in switch result { @@ -394,12 +398,13 @@ class CombineSynchronizerOfflineTests: XCTestCase { wait(for: [expectation], timeout: 0.5) } - func testShieldFundsSucceed() throws { + func testShieldFundsSucceed() async throws { let memo: Memo = .text(try MemoText("Some message")) let shieldingThreshold = Zatoshi(1) + let mockedSpendingKey = await data.spendingKey synchronizerMock.shieldFundsSpendingKeyMemoShieldingThresholdClosure = { receivedSpendingKey, receivedMemo, receivedShieldingThreshold in - XCTAssertEqual(receivedSpendingKey, self.data.spendingKey) + XCTAssertEqual(receivedSpendingKey, mockedSpendingKey) XCTAssertEqual(receivedMemo, memo) XCTAssertEqual(receivedShieldingThreshold, shieldingThreshold) return self.data.pendingTransactionEntity @@ -407,7 +412,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { let expectation = XCTestExpectation() - synchronizer.shieldFunds(spendingKey: data.spendingKey, memo: memo, shieldingThreshold: shieldingThreshold) + synchronizer.shieldFunds(spendingKey: mockedSpendingKey, memo: memo, shieldingThreshold: shieldingThreshold) .sink( receiveCompletion: { result in switch result { @@ -426,9 +431,10 @@ class CombineSynchronizerOfflineTests: XCTestCase { wait(for: [expectation], timeout: 0.5) } - func testShieldFundsThrowsError() throws { + func testShieldFundsThrowsError() async throws { let memo: Memo = .text(try MemoText("Some message")) let shieldingThreshold = Zatoshi(1) + let mockedSpendingKey = await data.spendingKey synchronizerMock.shieldFundsSpendingKeyMemoShieldingThresholdClosure = { _, _, _ in throw "Some error" @@ -436,7 +442,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { let expectation = XCTestExpectation() - synchronizer.shieldFunds(spendingKey: data.spendingKey, memo: memo, shieldingThreshold: shieldingThreshold) + synchronizer.shieldFunds(spendingKey: mockedSpendingKey, memo: memo, shieldingThreshold: shieldingThreshold) .sink( receiveCompletion: { result in switch result { @@ -581,7 +587,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { func testGetMemosForClearedTransactionSucceed() throws { let memo: Memo = .text(try MemoText("Some message")) - synchronizerMock.getMemosForTransactionClosure = { receivedTransaction in + synchronizerMock.getMemosForClearedTransactionClosure = { receivedTransaction in XCTAssertEqual(receivedTransaction.id, self.data.clearedTransaction.id) return [memo] } @@ -608,7 +614,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { } func testGetMemosForClearedTransactionThrowsError() { - synchronizerMock.getMemosForTransactionClosure = { _ in + synchronizerMock.getMemosForClearedTransactionClosure = { _ in throw "Some error" } @@ -802,7 +808,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { } func testAllConfirmedTransactionsSucceed() throws { - synchronizerMock.allConfirmedTransactionsFromTransactionClosure = { receivedTransaction, limit in + synchronizerMock.allConfirmedTransactionsFromLimitClosure = { receivedTransaction, limit in XCTAssertEqual(receivedTransaction.id, self.data.clearedTransaction.id) XCTAssertEqual(limit, 3) return [self.data.clearedTransaction] @@ -830,7 +836,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { } func testAllConfirmedTransactionsThrowsError() throws { - synchronizerMock.allConfirmedTransactionsFromTransactionClosure = { _, _ in + synchronizerMock.allConfirmedTransactionsFromLimitClosure = { _, _ in throw "Some error" } @@ -910,7 +916,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { let skippedEntity = UnspentTransactionOutputEntityMock(address: "addr2", txid: Data(), index: 1, script: Data(), valueZat: 2, height: 3) let refreshedUTXO = (inserted: [insertedEntity], skipped: [skippedEntity]) - synchronizerMock.refreshUTXOsAddressFromHeightClosure = { receivedAddress, receivedFromHeight in + synchronizerMock.refreshUTXOsAddressFromClosure = { receivedAddress, receivedFromHeight in XCTAssertEqual(receivedAddress, self.data.transparentAddress) XCTAssertEqual(receivedFromHeight, 121000) return refreshedUTXO @@ -939,7 +945,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { } func testRefreshUTXOsThrowsError() { - synchronizerMock.refreshUTXOsAddressFromHeightClosure = { _, _ in + synchronizerMock.refreshUTXOsAddressFromClosure = { _, _ in throw "Some error" } @@ -1126,7 +1132,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { } func testRewindSucceed() { - synchronizerMock.rewindPolicyClosure = { receivedPolicy in + synchronizerMock.rewindClosure = { receivedPolicy in if case .quick = receivedPolicy { } else { XCTFail("Unexpected policy \(receivedPolicy)") @@ -1155,7 +1161,7 @@ class CombineSynchronizerOfflineTests: XCTestCase { } func testRewindThrowsError() { - synchronizerMock.rewindPolicyClosure = { _ in + synchronizerMock.rewindClosure = { _ in return Fail(error: "some error").eraseToAnyPublisher() } diff --git a/Tests/OfflineTests/CompactBlockProcessorOfflineTests.swift b/Tests/OfflineTests/CompactBlockProcessorOfflineTests.swift index 32b3ce9e..9e6a71b8 100644 --- a/Tests/OfflineTests/CompactBlockProcessorOfflineTests.swift +++ b/Tests/OfflineTests/CompactBlockProcessorOfflineTests.swift @@ -11,24 +11,22 @@ import XCTest class CompactBlockProcessorOfflineTests: XCTestCase { let testFileManager = FileManager() - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) + var testTempDirectory: URL! override func setUpWithError() throws { try super.setUpWithError() - try self.testFileManager.createDirectory(at: self.testTempDirectory, withIntermediateDirectories: false) + testTempDirectory = Environment.uniqueTestTempDirectory + try self.testFileManager.createDirectory(at: testTempDirectory, withIntermediateDirectories: false) } override func tearDownWithError() throws { try super.tearDownWithError() - try FileManager.default.removeItem(at: self.testTempDirectory) + try FileManager.default.removeItem(at: testTempDirectory) } func testComputeProcessingRangeForSingleLoop() async throws { let network = ZcashNetworkBuilder.network(for: .testnet) - let realRustBackend = ZcashRustBackend.self + let rustBackend = ZcashRustBackend.makeForTests(fsBlockDbRoot: testTempDirectory, networkType: .testnet) let processorConfig = CompactBlockProcessor.Configuration.standard( for: network, @@ -44,7 +42,7 @@ class CompactBlockProcessorOfflineTests: XCTestCase { fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: realRustBackend, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -55,7 +53,7 @@ class CompactBlockProcessorOfflineTests: XCTestCase { let processor = CompactBlockProcessor( service: service, storage: storage, - backend: ZcashRustBackend.self, + rustBackend: rustBackend, config: processorConfig, metrics: SDKMetrics(), logger: logger diff --git a/Tests/OfflineTests/CompactBlockRepositoryTests.swift b/Tests/OfflineTests/CompactBlockRepositoryTests.swift index 28334f35..d65011a7 100644 --- a/Tests/OfflineTests/CompactBlockRepositoryTests.swift +++ b/Tests/OfflineTests/CompactBlockRepositoryTests.swift @@ -13,22 +13,22 @@ import XCTest class CompactBlockRepositoryTests: XCTestCase { let network = ZcashNetworkBuilder.network(for: .testnet) - - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) - let testFileManager = FileManager() + var rustBackend: ZcashRustBackendWelding! + var testTempDirectory: URL! override func setUpWithError() throws { try super.setUpWithError() - try self.testFileManager.createDirectory(at: self.testTempDirectory, withIntermediateDirectories: false) + testTempDirectory = Environment.uniqueTestTempDirectory + try self.testFileManager.createDirectory(at: testTempDirectory, withIntermediateDirectories: false) + rustBackend = ZcashRustBackend.makeForTests(fsBlockDbRoot: testTempDirectory, networkType: .testnet) } override func tearDownWithError() throws { try super.tearDownWithError() try? testFileManager.removeItem(at: testTempDirectory) + rustBackend = nil + testTempDirectory = nil } func testEmptyStorage() async throws { @@ -36,7 +36,7 @@ class CompactBlockRepositoryTests: XCTestCase { fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: ZcashRustBackend.self, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -55,7 +55,7 @@ class CompactBlockRepositoryTests: XCTestCase { fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: ZcashRustBackend.self, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -82,7 +82,7 @@ class CompactBlockRepositoryTests: XCTestCase { fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: ZcashRustBackend.self, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, @@ -108,7 +108,7 @@ class CompactBlockRepositoryTests: XCTestCase { fsBlockDbRoot: testTempDirectory, metadataStore: FSMetadataStore.live( fsBlockDbRoot: testTempDirectory, - rustBackend: ZcashRustBackend.self, + rustBackend: rustBackend, logger: logger ), blockDescriptor: .live, diff --git a/Tests/OfflineTests/DerivationToolTests/DerivationToolMainnetTests.swift b/Tests/OfflineTests/DerivationToolTests/DerivationToolMainnetTests.swift index a5107360..433e6151 100644 --- a/Tests/OfflineTests/DerivationToolTests/DerivationToolMainnetTests.swift +++ b/Tests/OfflineTests/DerivationToolTests/DerivationToolMainnetTests.swift @@ -16,7 +16,8 @@ class DerivationToolMainnetTests: XCTestCase { let testRecipientAddress = UnifiedAddress( validatedEncoding: """ u1l9f0l4348negsncgr9pxd9d3qaxagmqv3lnexcplmufpq7muffvfaue6ksevfvd7wrz7xrvn95rc5zjtn7ugkmgh5rnxswmcj30y0pw52pn0zjvy38rn2esfgve64rj5pcmazxgpyuj - """ + """, + networkType: .mainnet ) let expectedSpendingKey = UnifiedSpendingKey( @@ -45,82 +46,72 @@ class DerivationToolMainnetTests: XCTestCase { """) let expectedSaplingAddress = SaplingAddress(validatedEncoding: "zs1vp7kvlqr4n9gpehztr76lcn6skkss9p8keqs3nv8avkdtjrcctrvmk9a7u494kluv756jeee5k0") - - let derivationTool = DerivationTool(networkType: NetworkType.mainnet) + + let derivationTool = TestsData(networkType: .mainnet).derivationTools let expectedTransparentAddress = TransparentAddress(validatedEncoding: "t1dRJRY7GmyeykJnMH38mdQoaZtFhn1QmGz") - func testDeriveViewingKeysFromSeed() throws { + + func testDeriveViewingKeysFromSeed() async throws { let seedBytes = [UInt8](seedData) - let spendingKey = try derivationTool.deriveUnifiedSpendingKey(seed: seedBytes, accountIndex: 0) - - let viewingKey = try derivationTool.deriveUnifiedFullViewingKey(from: spendingKey) + let spendingKey = try await derivationTool.deriveUnifiedSpendingKey(seed: seedBytes, accountIndex: 0) + let viewingKey = try await derivationTool.deriveUnifiedFullViewingKey(from: spendingKey) XCTAssertEqual(expectedViewingKey, viewingKey) } - func testDeriveViewingKeyFromSpendingKeys() throws { - XCTAssertEqual( - expectedViewingKey, - try derivationTool.deriveUnifiedFullViewingKey(from: expectedSpendingKey) - ) + func testDeriveViewingKeyFromSpendingKeys() async throws { + let viewingKey = try await derivationTool.deriveUnifiedFullViewingKey(from: expectedSpendingKey) + XCTAssertEqual(expectedViewingKey, viewingKey) } - func testDeriveSpendingKeysFromSeed() throws { + func testDeriveSpendingKeysFromSeed() async throws { let seedBytes = [UInt8](seedData) - let spendingKey = try derivationTool.deriveUnifiedSpendingKey(seed: seedBytes, accountIndex: 0) + let spendingKey = try await derivationTool.deriveUnifiedSpendingKey(seed: seedBytes, accountIndex: 0) XCTAssertEqual(expectedSpendingKey, spendingKey) } - func testDeriveUnifiedSpendingKeyFromSeed() throws { + func testDeriveUnifiedSpendingKeyFromSeed() async throws { let account = 0 let seedBytes = [UInt8](seedData) - XCTAssertNoThrow(try derivationTool.deriveUnifiedSpendingKey(seed: seedBytes, accountIndex: account)) + _ = try await derivationTool.deriveUnifiedSpendingKey(seed: seedBytes, accountIndex: account) } func testGetTransparentAddressFromUA() throws { XCTAssertEqual( - try DerivationTool.transparentReceiver(from: testRecipientAddress), + try DerivationTool(networkType: .mainnet).transparentReceiver(from: testRecipientAddress), expectedTransparentAddress ) } func testIsValidViewingKey() { XCTAssertTrue( - DerivationTool.rustwelding.isValidSaplingExtendedFullViewingKey( + ZcashKeyDerivationBackend(networkType: .mainnet).isValidSaplingExtendedFullViewingKey( """ zxviews1q0dm7hkzqqqqpqplzv3f50rl4vay8uy5zg9e92f62lqg6gzu63rljety32xy5tcyenzuu3n386ws772nm6tp4sads8n37gff6nxmyz8dn9keehmapk0spc6pzx5ux\ epgu52xnwzxxnuja5tv465t9asppnj3eqncu3s7g3gzg5x8ss4ypkw08xwwyj7ky5skvnd9ldwj2u8fz2ry94s5q8p9lyp3j96yckudmp087d2jr2rnfuvjp7f56v78vpe658\ vljjddj7s645q399jd7 - """, - networkType: .mainnet + """ ) ) XCTAssertFalse( - DerivationTool.rustwelding.isValidSaplingExtendedFullViewingKey( - "zxviews1q0dm7hkzky5skvnd9ldwj2u8fz2ry94s5q8p9lyp3j96yckudmp087d2jr2rnfuvjp7f56v78vpe658vljjddj7s645q399jd7", - networkType: .mainnet + ZcashKeyDerivationBackend(networkType: .mainnet).isValidSaplingExtendedFullViewingKey( + "zxviews1q0dm7hkzky5skvnd9ldwj2u8fz2ry94s5q8p9lyp3j96yckudmp087d2jr2rnfuvjp7f56v78vpe658vljjddj7s645q399jd7" ) ) } - func testDeriveQuiteALotOfUnifiedKeysFromSeed() throws { + func testDeriveQuiteALotOfUnifiedKeysFromSeed() async throws { let numberOfAccounts: Int = 10 - let ufvks = try (0 ..< numberOfAccounts) - .map({ - try derivationTool.deriveUnifiedSpendingKey( - seed: [UInt8](seedData), - accountIndex: $0 - ) - }) - .map { - try derivationTool.deriveUnifiedFullViewingKey( - from: $0 - ) - } + var ufvks: [UnifiedFullViewingKey] = [] + for i in 0..(typecodes), @@ -70,7 +70,8 @@ final class UnifiedTypecodesTests: XCTestCase { validatedEncoding: """ u1l9f0l4348negsncgr9pxd9d3qaxagmqv3lnexcplmufpq7muffvfaue6ksevfvd7wrz7xrvn95rc5zjtn7ugkmgh5rnxswmcj30y0pw52pn0zjvy38rn2esfgve64rj5pcmazxg\ pyuj - """ + """, + networkType: .testnet ) XCTAssertEqual(try ua.availableReceiverTypecodes(), [.sapling, .p2pkh]) diff --git a/Tests/OfflineTests/WalletTests.swift b/Tests/OfflineTests/WalletTests.swift index 8625378f..c204a433 100644 --- a/Tests/OfflineTests/WalletTests.swift +++ b/Tests/OfflineTests/WalletTests.swift @@ -12,13 +12,8 @@ import XCTest @testable import ZcashLightClientKit class WalletTests: XCTestCase { - let testTempDirectory = URL(fileURLWithPath: NSString( - string: NSTemporaryDirectory() - ) - .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) - let testFileManager = FileManager() - + var testTempDirectory: URL! var dbData: URL! = nil var paramDestination: URL! = nil var network = ZcashNetworkBuilder.network(for: .testnet) @@ -26,8 +21,9 @@ class WalletTests: XCTestCase { override func setUpWithError() throws { try super.setUpWithError() + testTempDirectory = Environment.uniqueTestTempDirectory dbData = try __dataDbURL() - try self.testFileManager.createDirectory(at: self.testTempDirectory, withIntermediateDirectories: false) + try self.testFileManager.createDirectory(at: testTempDirectory, withIntermediateDirectories: false) paramDestination = try __documentsDirectory().appendingPathComponent("parameters") } @@ -36,17 +32,17 @@ class WalletTests: XCTestCase { if testFileManager.fileExists(atPath: dbData.absoluteString) { try testFileManager.trashItem(at: dbData, resultingItemURL: nil) } - try? self.testFileManager.removeItem(at: self.testTempDirectory) + try? self.testFileManager.removeItem(at: testTempDirectory) } func testWalletInitialization() async throws { - let derivationTool = DerivationTool(networkType: network.networkType) - let ufvk = try derivationTool.deriveUnifiedSpendingKey(seed: seedData.bytes, accountIndex: 0) - .map({ try derivationTool.deriveUnifiedFullViewingKey(from: $0) }) + let derivationTool = TestsData(networkType: network.networkType).derivationTools + let spendingKey = try await derivationTool.deriveUnifiedSpendingKey(seed: seedData.bytes, accountIndex: 0) + let viewingKey = try await derivationTool.deriveUnifiedFullViewingKey(from: spendingKey) let wallet = Initializer( cacheDbURL: nil, - fsBlockDbRoot: self.testTempDirectory, + fsBlockDbRoot: testTempDirectory, dataDbURL: try __dataDbURL(), pendingDbURL: try TestDbBuilder.pendingTransactionsDbURL(), endpoint: LightWalletEndpointBuilder.default, @@ -58,7 +54,7 @@ class WalletTests: XCTestCase { let synchronizer = SDKSynchronizer(initializer: wallet) do { - guard case .success = try await synchronizer.prepare(with: seedData.bytes, viewingKeys: [ufvk], walletBirthday: 663194) else { + guard case .success = try await synchronizer.prepare(with: seedData.bytes, viewingKeys: [viewingKey], walletBirthday: 663194) else { XCTFail("Failed to initDataDb. Expected `.success` got: `.seedRequired`") return } diff --git a/Tests/OfflineTests/ZcashRustBackendTests.swift b/Tests/OfflineTests/ZcashRustBackendTests.swift index a2deca94..1304459c 100644 --- a/Tests/OfflineTests/ZcashRustBackendTests.swift +++ b/Tests/OfflineTests/ZcashRustBackendTests.swift @@ -12,8 +12,8 @@ import XCTest class ZcashRustBackendTests: XCTestCase { var dbData: URL! + var rustBackend: ZcashRustBackendWelding! var dataDbHandle = TestDbHandle(originalDb: TestDbBuilder.prePopulatedDataDbURL()!) - let spendingKey = """ secret-extended-key-test1qvpevftsqqqqpqy52ut2vv24a2qh7nsukew7qg9pq6djfwyc3xt5vaxuenshp2hhspp9qmqvdh0gs2ljpwxders5jkwgyhgln0drjqaguaenfhehz4esdl4k\ wlm5t9q0l6wmzcrvcf5ed6dqzvct3e2ge7f6qdvzhp02m7sp5a0qjssrwpdh7u6tq89hl3wchuq8ljq8r8rwd6xdwh3nry9at80z7amnj3s6ah4jevnvfr08gxpws523z95g6dmn4wm6l3658\ @@ -24,23 +24,27 @@ class ZcashRustBackendTests: XCTestCase { let zpend: Int = 500_000 let networkType = NetworkType.testnet - + override func setUp() { super.setUp() + dbData = try! __dataDbURL() try? dataDbHandle.setUp() + + rustBackend = ZcashRustBackend.makeForTests(dbData: dbData, fsBlockDbRoot: Environment.uniqueTestTempDirectory, networkType: .testnet) } override func tearDown() { super.tearDown() try? FileManager.default.removeItem(at: dbData!) dataDbHandle.dispose() + rustBackend = nil } func testInitWithShortSeedAndFail() async throws { let seed = "testreferencealice" - let dbInit = try await ZcashRustBackend.initDataDb(dbData: self.dbData!, seed: nil, networkType: self.networkType) + let dbInit = try await rustBackend.initDataDb(seed: nil) guard case .success = dbInit else { XCTFail("Failed to initDataDb. Expected `.success` got: \(String(describing: dbInit))") @@ -48,76 +52,64 @@ class ZcashRustBackendTests: XCTestCase { } do { - _ = try await ZcashRustBackend.createAccount(dbData: dbData!, seed: Array(seed.utf8), networkType: networkType) + _ = try await rustBackend.createAccount(seed: Array(seed.utf8)) XCTFail("createAccount should fail here.") } catch { } } func testIsValidTransparentAddressFalse() { XCTAssertFalse( - ZcashRustBackend.isValidTransparentAddress( - "ztestsapling12k9m98wmpjts2m56wc60qzhgsfvlpxcwah268xk5yz4h942sd58jy3jamqyxjwums6hw7kfa4cc", - networkType: networkType + ZcashKeyDerivationBackend(networkType: networkType).isValidTransparentAddress( + "ztestsapling12k9m98wmpjts2m56wc60qzhgsfvlpxcwah268xk5yz4h942sd58jy3jamqyxjwums6hw7kfa4cc" ) ) } func testIsValidTransparentAddressTrue() { XCTAssertTrue( - ZcashRustBackend.isValidTransparentAddress( - "tmSwpioc7reeoNrYB9SKpWkurJz3yEj3ee7", - networkType: networkType + ZcashKeyDerivationBackend(networkType: networkType).isValidTransparentAddress( + "tmSwpioc7reeoNrYB9SKpWkurJz3yEj3ee7" ) ) } func testIsValidSaplingAddressTrue() { XCTAssertTrue( - ZcashRustBackend.isValidSaplingAddress( - "ztestsapling12k9m98wmpjts2m56wc60qzhgsfvlpxcwah268xk5yz4h942sd58jy3jamqyxjwums6hw7kfa4cc", - networkType: networkType + ZcashKeyDerivationBackend(networkType: networkType).isValidSaplingAddress( + "ztestsapling12k9m98wmpjts2m56wc60qzhgsfvlpxcwah268xk5yz4h942sd58jy3jamqyxjwums6hw7kfa4cc" ) ) } func testIsValidSaplingAddressFalse() { XCTAssertFalse( - ZcashRustBackend.isValidSaplingAddress( - "tmSwpioc7reeoNrYB9SKpWkurJz3yEj3ee7", - networkType: networkType + ZcashKeyDerivationBackend(networkType: networkType).isValidSaplingAddress( + "tmSwpioc7reeoNrYB9SKpWkurJz3yEj3ee7" ) ) } func testListTransparentReceivers() async throws { let testVector = [TestVector](TestVector.testVectors![0 ... 2]) - let network = NetworkType.mainnet let tempDBs = TemporaryDbBuilder.build() let seed = testVector[0].root_seed! + rustBackend = ZcashRustBackend.makeForTests(dbData: tempDBs.dataDB, fsBlockDbRoot: Environment.uniqueTestTempDirectory, networkType: .mainnet) try? FileManager.default.removeItem(at: tempDBs.dataDB) - let initResult = try await ZcashRustBackend.initDataDb( - dbData: tempDBs.dataDB, - seed: seed, - networkType: network - ) + let initResult = try await rustBackend.initDataDb(seed: seed) XCTAssertEqual(initResult, .success) - let usk = try await ZcashRustBackend.createAccount( - dbData: tempDBs.dataDB, - seed: seed, - networkType: network - ) + let usk = try await rustBackend.createAccount(seed: seed) XCTAssertEqual(usk.account, 0) let expectedReceivers = try testVector.map { - UnifiedAddress(validatedEncoding: $0.unified_addr!) + UnifiedAddress(validatedEncoding: $0.unified_addr!, networkType: .mainnet) } .map { try $0.transparentReceiver() } let expectedUAs = testVector.map { - UnifiedAddress(validatedEncoding: $0.unified_addr!) + UnifiedAddress(validatedEncoding: $0.unified_addr!, networkType: .mainnet) } guard expectedReceivers.count >= 2 else { @@ -127,19 +119,11 @@ class ZcashRustBackendTests: XCTestCase { var uAddresses: [UnifiedAddress] = [] for i in 0...2 { uAddresses.append( - try await ZcashRustBackend.getCurrentAddress( - dbData: tempDBs.dataDB, - account: 0, - networkType: network - ) + try await rustBackend.getCurrentAddress(account: 0) ) if i < 2 { - _ = try await ZcashRustBackend.getNextAvailableAddress( - dbData: tempDBs.dataDB, - account: 0, - networkType: network - ) + _ = try await rustBackend.getNextAvailableAddress(account: 0) } } @@ -148,11 +132,7 @@ class ZcashRustBackendTests: XCTestCase { expectedUAs ) - let actualReceivers = try await ZcashRustBackend.listTransparentReceivers( - dbData: tempDBs.dataDB, - account: 0, - networkType: network - ) + let actualReceivers = try await rustBackend.listTransparentReceivers(account: 0) XCTAssertEqual( expectedReceivers.sorted(), @@ -163,7 +143,7 @@ class ZcashRustBackendTests: XCTestCase { func testGetMetadataFromAddress() throws { let recipientAddress = "zs17mg40levjezevuhdp5pqrd52zere7r7vrjgdwn5sj4xsqtm20euwahv9anxmwr3y3kmwuz8k55a" - let metadata = ZcashRustBackend.getAddressMetadata(recipientAddress) + let metadata = ZcashKeyDerivationBackend.getAddressMetadata(recipientAddress) XCTAssertEqual(metadata?.networkType, .mainnet) XCTAssertEqual(metadata?.addressType, .sapling) diff --git a/Tests/PerformanceTests/SynchronizerTests.swift b/Tests/PerformanceTests/SynchronizerTests.swift index abbdc189..c659828d 100644 --- a/Tests/PerformanceTests/SynchronizerTests.swift +++ b/Tests/PerformanceTests/SynchronizerTests.swift @@ -26,6 +26,8 @@ class SynchronizerTests: XCTestCase { var coordinator: TestCoordinator! var cancellables: [AnyCancellable] = [] var sdkSynchronizerSyncStatusHandler: SDKSynchronizerSyncStatusHandler! = SDKSynchronizerSyncStatusHandler() + var rustBackend: ZcashRustBackendWelding! + var testTempDirectory: URL! let seedPhrase = """ wish puppy smile loan doll curve hole maze file ginger hair nose key relax knife witness cannon grab despair throw review deal slush frame @@ -33,11 +35,19 @@ class SynchronizerTests: XCTestCase { var birthday: BlockHeight = 1_730_000 + override func setUp() async throws { + try await super.setUp() + testTempDirectory = Environment.uniqueTestTempDirectory + rustBackend = ZcashRustBackend.makeForTests(fsBlockDbRoot: testTempDirectory, networkType: .mainnet) + } + override func tearDown() { super.tearDown() coordinator = nil cancellables = [] sdkSynchronizerSyncStatusHandler = nil + rustBackend = nil + testTempDirectory = nil } func testHundredBlocksSync() async throws { @@ -47,11 +57,11 @@ class SynchronizerTests: XCTestCase { return } let seedBytes = [UInt8](seedData) - let spendingKey = try derivationTool.deriveUnifiedSpendingKey( + let spendingKey = try await derivationTool.deriveUnifiedSpendingKey( seed: seedBytes, accountIndex: 0 ) - let ufvk = try derivationTool.deriveUnifiedFullViewingKey(from: spendingKey) + let ufvk = try await derivationTool.deriveUnifiedFullViewingKey(from: spendingKey) let network = ZcashNetworkBuilder.network(for: .mainnet) let endpoint = LightWalletEndpoint(address: "lightwalletd.electriccoin.co", port: 9067, secure: true) diff --git a/Tests/TestUtils/CompactBlockProcessorEventHandler.swift b/Tests/TestUtils/CompactBlockProcessorEventHandler.swift index 24644fc0..3e1a5920 100644 --- a/Tests/TestUtils/CompactBlockProcessorEventHandler.swift +++ b/Tests/TestUtils/CompactBlockProcessorEventHandler.swift @@ -29,6 +29,7 @@ class CompactBlockProcessorEventHandler { func subscribe(to blockProcessor: CompactBlockProcessor, expectations: [EventIdentifier: XCTestExpectation]) async { let closure: CompactBlockProcessor.EventClosure = { event in + print("Received event: \(event.identifier)") expectations[event.identifier]?.fulfill() } diff --git a/Tests/TestUtils/Sourcery/AutoMockable.stencil b/Tests/TestUtils/Sourcery/AutoMockable.stencil new file mode 100644 index 00000000..f14be10b --- /dev/null +++ b/Tests/TestUtils/Sourcery/AutoMockable.stencil @@ -0,0 +1,165 @@ +import Combine +@testable import ZcashLightClientKit + +{% macro methodName method%}{%if method|annotated:"mockedName" %}{{ method.annotations.mockedName }}{% else %}{% call swiftifyMethodName method.selectorName %}{% endif %}{% endmacro %} +{% macro swiftifyMethodName name %}{{ name | replace:"(","_" | replace:")","" | replace:":","_" | replace:"`","" | snakeToCamelCase | lowerFirstWord }}{% endmacro %} +{% macro methodNameUpper method%}{%if method|annotated:"mockedName" %}{{ method.annotations.mockedName }}{% else %}{% call swiftifyMethodNameUpper method.selectorName %}{% endif %}{% endmacro %} +{% macro swiftifyMethodNameUpper name %}{{ name | replace:"(","_" | replace:")","" | replace:":","_" | replace:"`","" | snakeToCamelCase | upperFirstLetter }}{% endmacro %} +{% macro methodThrowableErrorDeclaration method type %} + {% if method.isStatic %}static {% endif %}var {% call methodName method %}ThrowableError: Error? + {% call methodMockPropertySetter method type "ThrowableError" "Error?" %} +{% endmacro %} +{% macro methodMockPropertySetter method type postfix propertyType %} + {% if type|annotated:"mockActor" %}{% if not method.isStatic %} + func set{% call methodNameUpper method %}{{ postfix }}(_ param: {{ propertyType }}) async { + {% call methodName method %}{{ postfix }} = param + } + {% endif %}{% endif %} +{% endmacro %} +{% macro methodThrowableErrorUsage method %} + if let error = {% if method.isStatic %}Self.{% endif %}{% call methodName method %}ThrowableError { + throw error + } +{% endmacro %} +{% macro methodReceivedParameters method %} + {%if method.parameters.count == 1 %} + {% if method.isStatic %}Self.{% endif %}{% call methodName method %}Received{% for param in method.parameters %}{{ param.name|upperFirstLetter }} = {{ param.name }}{% endfor %} + {% else %} + {% if not method.parameters.count == 0 %} + {% if method.isStatic %}Self.{% endif %}{% call methodName method %}ReceivedArguments = ({% for param in method.parameters %}{{ param.name }}: {{ param.name }}{% if not forloop.last%}, {% endif %}{% endfor %}) + {% endif %} + {% endif %} +{% endmacro %} +{% macro methodClosureName method %}{% call methodName method %}Closure{% endmacro %} +{% macro paramTypeName param, method %}{% if method.annotations[param.name] %}{{method.annotations[param.name]}}{% else %}{{ param.typeName }}{% endif %}{% endmacro %} +{% macro unwrappedParamTypeName param, method %}{% if method.annotations[param.name] %}{{method.annotations[param.name]}}{% else %}{{ param.typeName.unwrappedTypeName }}{% endif %}{% endmacro %} +{% macro closureType method type %}({% for param in method.parameters %}{% call paramTypeName param, method %}{% if not forloop.last %}, {% endif %}{% endfor %}) {% if method.isAsync %}async {% endif %}{% if method.throws %}throws {% endif %}-> {% if method.isInitializer %}Void{% else %}{{ method.returnTypeName }}{% endif %}{% endmacro %} +{% macro methodClosureDeclaration method type %} + {% if method.isStatic %}static {% endif %}var {% call methodClosureName method %}: ({% call closureType method type %})? + {% if type|annotated:"mockActor" %}{% if not method.isStatic %} + func set{% call methodNameUpper method %}Closure(_ param: ({% call closureType method type %})?) async { + {% call methodName method %}Closure = param + } + {% endif %}{% endif %} +{% endmacro %} +{% macro methodClosureCallParameters method %}{% for param in method.parameters %}{{ param.name }}{% if not forloop.last %}, {% endif %}{% endfor %}{% endmacro %} +{% macro mockMethod method type %} +{% if method|!annotated:"skipAutoMock" %} + // MARK: - {{ method.shortName }} + + {% if ((type|annotated:"mockActor") and (method.isAsync) or (method.isStatic)) or (not type|annotated:"mockActor") %} + {% if method.throws %} + {% call methodThrowableErrorDeclaration method type %} + {% endif %} + {% if not method.isInitializer %} + {% if method.isStatic %}static {% endif %}var {% call methodName method %}CallsCount = 0 + {% if method.isStatic %}static {% endif %}var {% call methodName method %}Called: Bool { + return {% if method.isStatic %}Self.{% endif %}{% call methodName method %}CallsCount > 0 + } + {% endif %} + {% if method.parameters.count == 1 %} + {% if method.isStatic %}static {% endif %}var {% call methodName method %}Received{% for param in method.parameters %}{{ param.name|upperFirstLetter }}: {% if param.isClosure %}({% endif %}{% call unwrappedParamTypeName param, method %}{% if param.isClosure %}){% endif %}?{% endfor %} + {% else %}{% if not method.parameters.count == 0 %} + {% if method.isStatic %}static {% endif %}var {% call methodName method %}ReceivedArguments: ({% for param in method.parameters %}{{ param.name }}: {% if param.typeAttributes.escaping %}{% call unwrappedParamTypeName param, method %}{% else %}{% call paramTypeName param, method %}{% endif %}{% if not forloop.last %}, {% endif %}{% endfor %})? + {% endif %}{% endif %} + {% if not method.returnTypeName.isVoid and not method.isInitializer %} + {% if method.isStatic %}static {% endif %}var {% call methodName method %}ReturnValue: {{ method.returnTypeName }}{{ '!' if not method.isOptionalReturnType }} + {% call methodMockPropertySetter method type "ReturnValue" method.returnTypeName %} + {% endif %} + {% call methodClosureDeclaration method type %} + {% endif %} + +{% if method.isInitializer %} + required {{ method.name }} { + {% call methodReceivedParameters method %} + {% call methodClosureName method %}?({% call methodClosureCallParameters method %}) + } +{% else %} + {% if (not method.isAsync) and (not method.isStatic) and (type|annotated:"mockActor") %}nonisolated {% endif %}{% if method.isStatic %}static {% endif %}func {{ method.name }}{% if method.isAsync %} async{% endif %}{% if method.throws %} throws{% endif %}{% if not method.returnTypeName.isVoid %} -> {{ method.returnTypeName }}{% endif %} { + {% if ((type|annotated:"mockActor") and ((method.isAsync) or (method.isStatic))) or (not type|annotated:"mockActor") %} + {% if method.throws %} + {% call methodThrowableErrorUsage method %} + {% endif %} + {% if method.isStatic %}Self.{% endif %}{% call methodName method %}CallsCount += 1 + {% call methodReceivedParameters method %} + {% if method.returnTypeName.isVoid %} + {% if method.throws %}try {% endif %}{% if method.isAsync %}await {% endif %}{% call methodClosureName method %}?({% call methodClosureCallParameters method %}) + {% else %} + if let closure = {% if method.isStatic %}Self.{% endif %}{% call methodClosureName method %} { + return {% if method.throws %}try {% endif %}{% if method.isAsync %}await {% endif %}closure({% call methodClosureCallParameters method %}) + } else { + return {% if method.isStatic %}Self.{% endif %}{% call methodName method %}ReturnValue + } + {% endif %} + {% else %} + {% if method.throws %}try {% endif %}{% call methodClosureName method %}!({% call methodClosureCallParameters method %}) + {% endif %} + } + +{% endif %} +{% endif %} +{% endmacro %} +{% macro mockOptionalVariable variable %} + var {% call mockedVariableName variable %}: {{ variable.typeName }} +{% endmacro %} +{% macro mockNonOptionalArrayOrDictionaryVariable variable %} + var {% call mockedVariableName variable %}: {{ variable.typeName }} { + get{% if variable.isAsync %} async{% endif %} { return {% call underlyingMockedVariableName variable %} } + } + var {% call underlyingMockedVariableName variable %}: {{ variable.typeName }} = {% if variable.isArray %}[]{% elif variable.isDictionary %}[:]{% endif %} +{% endmacro %} +{% macro mockNonOptionalVariable variable %} + var {% call mockedVariableName variable %}: {{ variable.typeName }} { + get { return {% call underlyingMockedVariableName variable %} } + } + var {% call underlyingMockedVariableName variable %}: {% if variable.typeName.isClosure %}({{ variable.typeName }})!{% else %}{{ variable.typeName }}!{% endif %} +{% endmacro %} + +{% macro underlyingMockedVariableName variable %}underlying{{ variable.name|upperFirstLetter }}{% endmacro %} +{% macro initialMockedVariableValue variable %}initial{{ variable.name|upperFirstLetter }}{% endmacro %} +{% macro mockedVariableName variable %}{{ variable.name }}{% endmacro %} + +// MARK: - AutoMockable protocols +{% for type in types.protocols where type.based.AutoMockable or type|annotated:"AutoMockable" %}{% if type.name != "AutoMockable" %} +{% if type|annotated:"moduleName" %} +/// Imported from {{ type.annotations.moduleName }} module +{% endif %} +{% if type|annotated:"targetOS" %} +#if os({{ type.annotations.targetOS }}) +{% endif %} +{% if type|annotated:"mockActor" %}actor {% else %}class {% endif %}{{ type.name }}Mock: {% if type|annotated:"baseClass" %}{{ type.annotations.baseClass }}, {% endif %}{% if type|annotated:"moduleName" %}{{ type.annotations.moduleName }}.{% endif %}{{ type.name }} { + +{% for method in type.allMethods|!definedInExtension %} + {% if (not method.isAsync) and (not method.isStatic) and (type|annotated:"mockActor") %} + nonisolated let {% call methodName method %}Closure: ({% call closureType method type %})? + {% endif %} +{% endfor %} + + init( +{% for method in type.allMethods|!definedInExtension where ((not method.isAsync) and (not method.isStatic) and (type|annotated:"mockActor")) %} + {% call methodName method %}Closure: ({% call closureType method type %})? = nil{% if not forloop.last %},{% endif %} +{% endfor %} + ) { + {% for method in type.allMethods|!definedInExtension %} + {% if (not method.isAsync) and (not method.isStatic) and (type|annotated:"mockActor") %} + self.{% call methodName method %}Closure = {% call methodName method %}Closure + {% endif %} + {% endfor %} + } +{% for variable in type.allVariables|!definedInExtension %} +{% if variable|!annotated:"skipAutoMock" %} + {% if variable.isOptional %}{% call mockOptionalVariable variable %} + {% elif variable.isArray or variable.isDictionary %}{% call mockNonOptionalArrayOrDictionaryVariable variable %} + {% else %}{% call mockNonOptionalVariable variable %} + {% endif %} +{% endif %} +{% endfor %} + +{% for method in type.allMethods|!definedInExtension %} + {% call mockMethod method type %} +{% endfor %} +} +{% if type|annotated:"targetOS" %} +#endif +{% endif %} +{% endif %}{% endfor %} diff --git a/Tests/TestUtils/Sourcery/AutoMockable.swift b/Tests/TestUtils/Sourcery/AutoMockable.swift new file mode 100644 index 00000000..c1b6e08d --- /dev/null +++ b/Tests/TestUtils/Sourcery/AutoMockable.swift @@ -0,0 +1,18 @@ +// +// AutoMockable.swift +// +// +// Created by Michal Fousek on 04.04.2023. +// + +/// This file defines types for which we need to generate mocks for usage in Player tests. +/// Each type must appear in appropriate section according to which module it comes from. + +// sourcery:begin: AutoMockable + +@testable import ZcashLightClientKit + +extension ZcashRustBackendWelding { } +extension Synchronizer { } + +// sourcery:end: diff --git a/Tests/TestUtils/Sourcery/GeneratedMocks/AutoMockable.generated.swift b/Tests/TestUtils/Sourcery/GeneratedMocks/AutoMockable.generated.swift new file mode 100644 index 00000000..60f60161 --- /dev/null +++ b/Tests/TestUtils/Sourcery/GeneratedMocks/AutoMockable.generated.swift @@ -0,0 +1,1302 @@ +// Generated using Sourcery 2.0.2 — https://github.com/krzysztofzablocki/Sourcery +// DO NOT EDIT +import Combine +@testable import ZcashLightClientKit + + + +// MARK: - AutoMockable protocols +class SynchronizerMock: Synchronizer { + + + init( + ) { + } + var alias: ZcashSynchronizerAlias { + get { return underlyingAlias } + } + var underlyingAlias: ZcashSynchronizerAlias! + var latestState: SynchronizerState { + get { return underlyingLatestState } + } + var underlyingLatestState: SynchronizerState! + var connectionState: ConnectionState { + get { return underlyingConnectionState } + } + var underlyingConnectionState: ConnectionState! + var stateStream: AnyPublisher { + get { return underlyingStateStream } + } + var underlyingStateStream: AnyPublisher! + var eventStream: AnyPublisher { + get { return underlyingEventStream } + } + var underlyingEventStream: AnyPublisher! + var metrics: SDKMetrics { + get { return underlyingMetrics } + } + var underlyingMetrics: SDKMetrics! + var pendingTransactions: [PendingTransactionEntity] { + get async { return underlyingPendingTransactions } + } + var underlyingPendingTransactions: [PendingTransactionEntity] = [] + var clearedTransactions: [ZcashTransaction.Overview] { + get async { return underlyingClearedTransactions } + } + var underlyingClearedTransactions: [ZcashTransaction.Overview] = [] + var sentTransactions: [ZcashTransaction.Sent] { + get async { return underlyingSentTransactions } + } + var underlyingSentTransactions: [ZcashTransaction.Sent] = [] + var receivedTransactions: [ZcashTransaction.Received] { + get async { return underlyingReceivedTransactions } + } + var underlyingReceivedTransactions: [ZcashTransaction.Received] = [] + + // MARK: - prepare + + var prepareWithViewingKeysWalletBirthdayThrowableError: Error? + var prepareWithViewingKeysWalletBirthdayCallsCount = 0 + var prepareWithViewingKeysWalletBirthdayCalled: Bool { + return prepareWithViewingKeysWalletBirthdayCallsCount > 0 + } + var prepareWithViewingKeysWalletBirthdayReceivedArguments: (seed: [UInt8]?, viewingKeys: [UnifiedFullViewingKey], walletBirthday: BlockHeight)? + var prepareWithViewingKeysWalletBirthdayReturnValue: Initializer.InitializationResult! + var prepareWithViewingKeysWalletBirthdayClosure: (([UInt8]?, [UnifiedFullViewingKey], BlockHeight) async throws -> Initializer.InitializationResult)? + + func prepare(with seed: [UInt8]?, viewingKeys: [UnifiedFullViewingKey], walletBirthday: BlockHeight) async throws -> Initializer.InitializationResult { + if let error = prepareWithViewingKeysWalletBirthdayThrowableError { + throw error + } + prepareWithViewingKeysWalletBirthdayCallsCount += 1 + prepareWithViewingKeysWalletBirthdayReceivedArguments = (seed: seed, viewingKeys: viewingKeys, walletBirthday: walletBirthday) + if let closure = prepareWithViewingKeysWalletBirthdayClosure { + return try await closure(seed, viewingKeys, walletBirthday) + } else { + return prepareWithViewingKeysWalletBirthdayReturnValue + } + } + + // MARK: - start + + var startRetryThrowableError: Error? + var startRetryCallsCount = 0 + var startRetryCalled: Bool { + return startRetryCallsCount > 0 + } + var startRetryReceivedRetry: Bool? + var startRetryClosure: ((Bool) async throws -> Void)? + + func start(retry: Bool) async throws { + if let error = startRetryThrowableError { + throw error + } + startRetryCallsCount += 1 + startRetryReceivedRetry = retry + try await startRetryClosure?(retry) + } + + // MARK: - stop + + var stopCallsCount = 0 + var stopCalled: Bool { + return stopCallsCount > 0 + } + var stopClosure: (() async -> Void)? + + func stop() async { + stopCallsCount += 1 + await stopClosure?() + } + + // MARK: - getSaplingAddress + + var getSaplingAddressAccountIndexThrowableError: Error? + var getSaplingAddressAccountIndexCallsCount = 0 + var getSaplingAddressAccountIndexCalled: Bool { + return getSaplingAddressAccountIndexCallsCount > 0 + } + var getSaplingAddressAccountIndexReceivedAccountIndex: Int? + var getSaplingAddressAccountIndexReturnValue: SaplingAddress! + var getSaplingAddressAccountIndexClosure: ((Int) async throws -> SaplingAddress)? + + func getSaplingAddress(accountIndex: Int) async throws -> SaplingAddress { + if let error = getSaplingAddressAccountIndexThrowableError { + throw error + } + getSaplingAddressAccountIndexCallsCount += 1 + getSaplingAddressAccountIndexReceivedAccountIndex = accountIndex + if let closure = getSaplingAddressAccountIndexClosure { + return try await closure(accountIndex) + } else { + return getSaplingAddressAccountIndexReturnValue + } + } + + // MARK: - getUnifiedAddress + + var getUnifiedAddressAccountIndexThrowableError: Error? + var getUnifiedAddressAccountIndexCallsCount = 0 + var getUnifiedAddressAccountIndexCalled: Bool { + return getUnifiedAddressAccountIndexCallsCount > 0 + } + var getUnifiedAddressAccountIndexReceivedAccountIndex: Int? + var getUnifiedAddressAccountIndexReturnValue: UnifiedAddress! + var getUnifiedAddressAccountIndexClosure: ((Int) async throws -> UnifiedAddress)? + + func getUnifiedAddress(accountIndex: Int) async throws -> UnifiedAddress { + if let error = getUnifiedAddressAccountIndexThrowableError { + throw error + } + getUnifiedAddressAccountIndexCallsCount += 1 + getUnifiedAddressAccountIndexReceivedAccountIndex = accountIndex + if let closure = getUnifiedAddressAccountIndexClosure { + return try await closure(accountIndex) + } else { + return getUnifiedAddressAccountIndexReturnValue + } + } + + // MARK: - getTransparentAddress + + var getTransparentAddressAccountIndexThrowableError: Error? + var getTransparentAddressAccountIndexCallsCount = 0 + var getTransparentAddressAccountIndexCalled: Bool { + return getTransparentAddressAccountIndexCallsCount > 0 + } + var getTransparentAddressAccountIndexReceivedAccountIndex: Int? + var getTransparentAddressAccountIndexReturnValue: TransparentAddress! + var getTransparentAddressAccountIndexClosure: ((Int) async throws -> TransparentAddress)? + + func getTransparentAddress(accountIndex: Int) async throws -> TransparentAddress { + if let error = getTransparentAddressAccountIndexThrowableError { + throw error + } + getTransparentAddressAccountIndexCallsCount += 1 + getTransparentAddressAccountIndexReceivedAccountIndex = accountIndex + if let closure = getTransparentAddressAccountIndexClosure { + return try await closure(accountIndex) + } else { + return getTransparentAddressAccountIndexReturnValue + } + } + + // MARK: - sendToAddress + + var sendToAddressSpendingKeyZatoshiToAddressMemoThrowableError: Error? + var sendToAddressSpendingKeyZatoshiToAddressMemoCallsCount = 0 + var sendToAddressSpendingKeyZatoshiToAddressMemoCalled: Bool { + return sendToAddressSpendingKeyZatoshiToAddressMemoCallsCount > 0 + } + var sendToAddressSpendingKeyZatoshiToAddressMemoReceivedArguments: (spendingKey: UnifiedSpendingKey, zatoshi: Zatoshi, toAddress: Recipient, memo: Memo?)? + var sendToAddressSpendingKeyZatoshiToAddressMemoReturnValue: PendingTransactionEntity! + var sendToAddressSpendingKeyZatoshiToAddressMemoClosure: ((UnifiedSpendingKey, Zatoshi, Recipient, Memo?) async throws -> PendingTransactionEntity)? + + func sendToAddress(spendingKey: UnifiedSpendingKey, zatoshi: Zatoshi, toAddress: Recipient, memo: Memo?) async throws -> PendingTransactionEntity { + if let error = sendToAddressSpendingKeyZatoshiToAddressMemoThrowableError { + throw error + } + sendToAddressSpendingKeyZatoshiToAddressMemoCallsCount += 1 + sendToAddressSpendingKeyZatoshiToAddressMemoReceivedArguments = (spendingKey: spendingKey, zatoshi: zatoshi, toAddress: toAddress, memo: memo) + if let closure = sendToAddressSpendingKeyZatoshiToAddressMemoClosure { + return try await closure(spendingKey, zatoshi, toAddress, memo) + } else { + return sendToAddressSpendingKeyZatoshiToAddressMemoReturnValue + } + } + + // MARK: - shieldFunds + + var shieldFundsSpendingKeyMemoShieldingThresholdThrowableError: Error? + var shieldFundsSpendingKeyMemoShieldingThresholdCallsCount = 0 + var shieldFundsSpendingKeyMemoShieldingThresholdCalled: Bool { + return shieldFundsSpendingKeyMemoShieldingThresholdCallsCount > 0 + } + var shieldFundsSpendingKeyMemoShieldingThresholdReceivedArguments: (spendingKey: UnifiedSpendingKey, memo: Memo, shieldingThreshold: Zatoshi)? + var shieldFundsSpendingKeyMemoShieldingThresholdReturnValue: PendingTransactionEntity! + var shieldFundsSpendingKeyMemoShieldingThresholdClosure: ((UnifiedSpendingKey, Memo, Zatoshi) async throws -> PendingTransactionEntity)? + + func shieldFunds(spendingKey: UnifiedSpendingKey, memo: Memo, shieldingThreshold: Zatoshi) async throws -> PendingTransactionEntity { + if let error = shieldFundsSpendingKeyMemoShieldingThresholdThrowableError { + throw error + } + shieldFundsSpendingKeyMemoShieldingThresholdCallsCount += 1 + shieldFundsSpendingKeyMemoShieldingThresholdReceivedArguments = (spendingKey: spendingKey, memo: memo, shieldingThreshold: shieldingThreshold) + if let closure = shieldFundsSpendingKeyMemoShieldingThresholdClosure { + return try await closure(spendingKey, memo, shieldingThreshold) + } else { + return shieldFundsSpendingKeyMemoShieldingThresholdReturnValue + } + } + + // MARK: - cancelSpend + + var cancelSpendTransactionCallsCount = 0 + var cancelSpendTransactionCalled: Bool { + return cancelSpendTransactionCallsCount > 0 + } + var cancelSpendTransactionReceivedTransaction: PendingTransactionEntity? + var cancelSpendTransactionReturnValue: Bool! + var cancelSpendTransactionClosure: ((PendingTransactionEntity) async -> Bool)? + + func cancelSpend(transaction: PendingTransactionEntity) async -> Bool { + cancelSpendTransactionCallsCount += 1 + cancelSpendTransactionReceivedTransaction = transaction + if let closure = cancelSpendTransactionClosure { + return await closure(transaction) + } else { + return cancelSpendTransactionReturnValue + } + } + + // MARK: - paginatedTransactions + + var paginatedTransactionsOfCallsCount = 0 + var paginatedTransactionsOfCalled: Bool { + return paginatedTransactionsOfCallsCount > 0 + } + var paginatedTransactionsOfReceivedKind: TransactionKind? + var paginatedTransactionsOfReturnValue: PaginatedTransactionRepository! + var paginatedTransactionsOfClosure: ((TransactionKind) -> PaginatedTransactionRepository)? + + func paginatedTransactions(of kind: TransactionKind) -> PaginatedTransactionRepository { + paginatedTransactionsOfCallsCount += 1 + paginatedTransactionsOfReceivedKind = kind + if let closure = paginatedTransactionsOfClosure { + return closure(kind) + } else { + return paginatedTransactionsOfReturnValue + } + } + + // MARK: - getMemos + + var getMemosForClearedTransactionThrowableError: Error? + var getMemosForClearedTransactionCallsCount = 0 + var getMemosForClearedTransactionCalled: Bool { + return getMemosForClearedTransactionCallsCount > 0 + } + var getMemosForClearedTransactionReceivedTransaction: ZcashTransaction.Overview? + var getMemosForClearedTransactionReturnValue: [Memo]! + var getMemosForClearedTransactionClosure: ((ZcashTransaction.Overview) async throws -> [Memo])? + + func getMemos(for transaction: ZcashTransaction.Overview) async throws -> [Memo] { + if let error = getMemosForClearedTransactionThrowableError { + throw error + } + getMemosForClearedTransactionCallsCount += 1 + getMemosForClearedTransactionReceivedTransaction = transaction + if let closure = getMemosForClearedTransactionClosure { + return try await closure(transaction) + } else { + return getMemosForClearedTransactionReturnValue + } + } + + // MARK: - getMemos + + var getMemosForReceivedTransactionThrowableError: Error? + var getMemosForReceivedTransactionCallsCount = 0 + var getMemosForReceivedTransactionCalled: Bool { + return getMemosForReceivedTransactionCallsCount > 0 + } + var getMemosForReceivedTransactionReceivedReceivedTransaction: ZcashTransaction.Received? + var getMemosForReceivedTransactionReturnValue: [Memo]! + var getMemosForReceivedTransactionClosure: ((ZcashTransaction.Received) async throws -> [Memo])? + + func getMemos(for receivedTransaction: ZcashTransaction.Received) async throws -> [Memo] { + if let error = getMemosForReceivedTransactionThrowableError { + throw error + } + getMemosForReceivedTransactionCallsCount += 1 + getMemosForReceivedTransactionReceivedReceivedTransaction = receivedTransaction + if let closure = getMemosForReceivedTransactionClosure { + return try await closure(receivedTransaction) + } else { + return getMemosForReceivedTransactionReturnValue + } + } + + // MARK: - getMemos + + var getMemosForSentTransactionThrowableError: Error? + var getMemosForSentTransactionCallsCount = 0 + var getMemosForSentTransactionCalled: Bool { + return getMemosForSentTransactionCallsCount > 0 + } + var getMemosForSentTransactionReceivedSentTransaction: ZcashTransaction.Sent? + var getMemosForSentTransactionReturnValue: [Memo]! + var getMemosForSentTransactionClosure: ((ZcashTransaction.Sent) async throws -> [Memo])? + + func getMemos(for sentTransaction: ZcashTransaction.Sent) async throws -> [Memo] { + if let error = getMemosForSentTransactionThrowableError { + throw error + } + getMemosForSentTransactionCallsCount += 1 + getMemosForSentTransactionReceivedSentTransaction = sentTransaction + if let closure = getMemosForSentTransactionClosure { + return try await closure(sentTransaction) + } else { + return getMemosForSentTransactionReturnValue + } + } + + // MARK: - getRecipients + + var getRecipientsForClearedTransactionCallsCount = 0 + var getRecipientsForClearedTransactionCalled: Bool { + return getRecipientsForClearedTransactionCallsCount > 0 + } + var getRecipientsForClearedTransactionReceivedTransaction: ZcashTransaction.Overview? + var getRecipientsForClearedTransactionReturnValue: [TransactionRecipient]! + var getRecipientsForClearedTransactionClosure: ((ZcashTransaction.Overview) async -> [TransactionRecipient])? + + func getRecipients(for transaction: ZcashTransaction.Overview) async -> [TransactionRecipient] { + getRecipientsForClearedTransactionCallsCount += 1 + getRecipientsForClearedTransactionReceivedTransaction = transaction + if let closure = getRecipientsForClearedTransactionClosure { + return await closure(transaction) + } else { + return getRecipientsForClearedTransactionReturnValue + } + } + + // MARK: - getRecipients + + var getRecipientsForSentTransactionCallsCount = 0 + var getRecipientsForSentTransactionCalled: Bool { + return getRecipientsForSentTransactionCallsCount > 0 + } + var getRecipientsForSentTransactionReceivedTransaction: ZcashTransaction.Sent? + var getRecipientsForSentTransactionReturnValue: [TransactionRecipient]! + var getRecipientsForSentTransactionClosure: ((ZcashTransaction.Sent) async -> [TransactionRecipient])? + + func getRecipients(for transaction: ZcashTransaction.Sent) async -> [TransactionRecipient] { + getRecipientsForSentTransactionCallsCount += 1 + getRecipientsForSentTransactionReceivedTransaction = transaction + if let closure = getRecipientsForSentTransactionClosure { + return await closure(transaction) + } else { + return getRecipientsForSentTransactionReturnValue + } + } + + // MARK: - allConfirmedTransactions + + var allConfirmedTransactionsFromLimitThrowableError: Error? + var allConfirmedTransactionsFromLimitCallsCount = 0 + var allConfirmedTransactionsFromLimitCalled: Bool { + return allConfirmedTransactionsFromLimitCallsCount > 0 + } + var allConfirmedTransactionsFromLimitReceivedArguments: (transaction: ZcashTransaction.Overview, limit: Int)? + var allConfirmedTransactionsFromLimitReturnValue: [ZcashTransaction.Overview]! + var allConfirmedTransactionsFromLimitClosure: ((ZcashTransaction.Overview, Int) async throws -> [ZcashTransaction.Overview])? + + func allConfirmedTransactions(from transaction: ZcashTransaction.Overview, limit: Int) async throws -> [ZcashTransaction.Overview] { + if let error = allConfirmedTransactionsFromLimitThrowableError { + throw error + } + allConfirmedTransactionsFromLimitCallsCount += 1 + allConfirmedTransactionsFromLimitReceivedArguments = (transaction: transaction, limit: limit) + if let closure = allConfirmedTransactionsFromLimitClosure { + return try await closure(transaction, limit) + } else { + return allConfirmedTransactionsFromLimitReturnValue + } + } + + // MARK: - latestHeight + + var latestHeightThrowableError: Error? + var latestHeightCallsCount = 0 + var latestHeightCalled: Bool { + return latestHeightCallsCount > 0 + } + var latestHeightReturnValue: BlockHeight! + var latestHeightClosure: (() async throws -> BlockHeight)? + + func latestHeight() async throws -> BlockHeight { + if let error = latestHeightThrowableError { + throw error + } + latestHeightCallsCount += 1 + if let closure = latestHeightClosure { + return try await closure() + } else { + return latestHeightReturnValue + } + } + + // MARK: - refreshUTXOs + + var refreshUTXOsAddressFromThrowableError: Error? + var refreshUTXOsAddressFromCallsCount = 0 + var refreshUTXOsAddressFromCalled: Bool { + return refreshUTXOsAddressFromCallsCount > 0 + } + var refreshUTXOsAddressFromReceivedArguments: (address: TransparentAddress, height: BlockHeight)? + var refreshUTXOsAddressFromReturnValue: RefreshedUTXOs! + var refreshUTXOsAddressFromClosure: ((TransparentAddress, BlockHeight) async throws -> RefreshedUTXOs)? + + func refreshUTXOs(address: TransparentAddress, from height: BlockHeight) async throws -> RefreshedUTXOs { + if let error = refreshUTXOsAddressFromThrowableError { + throw error + } + refreshUTXOsAddressFromCallsCount += 1 + refreshUTXOsAddressFromReceivedArguments = (address: address, height: height) + if let closure = refreshUTXOsAddressFromClosure { + return try await closure(address, height) + } else { + return refreshUTXOsAddressFromReturnValue + } + } + + // MARK: - getTransparentBalance + + var getTransparentBalanceAccountIndexThrowableError: Error? + var getTransparentBalanceAccountIndexCallsCount = 0 + var getTransparentBalanceAccountIndexCalled: Bool { + return getTransparentBalanceAccountIndexCallsCount > 0 + } + var getTransparentBalanceAccountIndexReceivedAccountIndex: Int? + var getTransparentBalanceAccountIndexReturnValue: WalletBalance! + var getTransparentBalanceAccountIndexClosure: ((Int) async throws -> WalletBalance)? + + func getTransparentBalance(accountIndex: Int) async throws -> WalletBalance { + if let error = getTransparentBalanceAccountIndexThrowableError { + throw error + } + getTransparentBalanceAccountIndexCallsCount += 1 + getTransparentBalanceAccountIndexReceivedAccountIndex = accountIndex + if let closure = getTransparentBalanceAccountIndexClosure { + return try await closure(accountIndex) + } else { + return getTransparentBalanceAccountIndexReturnValue + } + } + + // MARK: - getShieldedBalance + + var getShieldedBalanceAccountIndexThrowableError: Error? + var getShieldedBalanceAccountIndexCallsCount = 0 + var getShieldedBalanceAccountIndexCalled: Bool { + return getShieldedBalanceAccountIndexCallsCount > 0 + } + var getShieldedBalanceAccountIndexReceivedAccountIndex: Int? + var getShieldedBalanceAccountIndexReturnValue: Zatoshi! + var getShieldedBalanceAccountIndexClosure: ((Int) async throws -> Zatoshi)? + + func getShieldedBalance(accountIndex: Int) async throws -> Zatoshi { + if let error = getShieldedBalanceAccountIndexThrowableError { + throw error + } + getShieldedBalanceAccountIndexCallsCount += 1 + getShieldedBalanceAccountIndexReceivedAccountIndex = accountIndex + if let closure = getShieldedBalanceAccountIndexClosure { + return try await closure(accountIndex) + } else { + return getShieldedBalanceAccountIndexReturnValue + } + } + + // MARK: - getShieldedVerifiedBalance + + var getShieldedVerifiedBalanceAccountIndexThrowableError: Error? + var getShieldedVerifiedBalanceAccountIndexCallsCount = 0 + var getShieldedVerifiedBalanceAccountIndexCalled: Bool { + return getShieldedVerifiedBalanceAccountIndexCallsCount > 0 + } + var getShieldedVerifiedBalanceAccountIndexReceivedAccountIndex: Int? + var getShieldedVerifiedBalanceAccountIndexReturnValue: Zatoshi! + var getShieldedVerifiedBalanceAccountIndexClosure: ((Int) async throws -> Zatoshi)? + + func getShieldedVerifiedBalance(accountIndex: Int) async throws -> Zatoshi { + if let error = getShieldedVerifiedBalanceAccountIndexThrowableError { + throw error + } + getShieldedVerifiedBalanceAccountIndexCallsCount += 1 + getShieldedVerifiedBalanceAccountIndexReceivedAccountIndex = accountIndex + if let closure = getShieldedVerifiedBalanceAccountIndexClosure { + return try await closure(accountIndex) + } else { + return getShieldedVerifiedBalanceAccountIndexReturnValue + } + } + + // MARK: - rewind + + var rewindCallsCount = 0 + var rewindCalled: Bool { + return rewindCallsCount > 0 + } + var rewindReceivedPolicy: RewindPolicy? + var rewindReturnValue: AnyPublisher! + var rewindClosure: ((RewindPolicy) -> AnyPublisher)? + + func rewind(_ policy: RewindPolicy) -> AnyPublisher { + rewindCallsCount += 1 + rewindReceivedPolicy = policy + if let closure = rewindClosure { + return closure(policy) + } else { + return rewindReturnValue + } + } + + // MARK: - wipe + + var wipeCallsCount = 0 + var wipeCalled: Bool { + return wipeCallsCount > 0 + } + var wipeReturnValue: AnyPublisher! + var wipeClosure: (() -> AnyPublisher)? + + func wipe() -> AnyPublisher { + wipeCallsCount += 1 + if let closure = wipeClosure { + return closure() + } else { + return wipeReturnValue + } + } + +} +actor ZcashRustBackendWeldingMock: ZcashRustBackendWelding { + + nonisolated let consensusBranchIdForHeightClosure: ((Int32) throws -> Int32)? + + init( + consensusBranchIdForHeightClosure: ((Int32) throws -> Int32)? = nil + ) { + self.consensusBranchIdForHeightClosure = consensusBranchIdForHeightClosure + } + + // MARK: - createAccount + + var createAccountSeedThrowableError: Error? + func setCreateAccountSeedThrowableError(_ param: Error?) async { + createAccountSeedThrowableError = param + } + var createAccountSeedCallsCount = 0 + var createAccountSeedCalled: Bool { + return createAccountSeedCallsCount > 0 + } + var createAccountSeedReceivedSeed: [UInt8]? + var createAccountSeedReturnValue: UnifiedSpendingKey! + func setCreateAccountSeedReturnValue(_ param: UnifiedSpendingKey) async { + createAccountSeedReturnValue = param + } + var createAccountSeedClosure: (([UInt8]) async throws -> UnifiedSpendingKey)? + func setCreateAccountSeedClosure(_ param: (([UInt8]) async throws -> UnifiedSpendingKey)?) async { + createAccountSeedClosure = param + } + + func createAccount(seed: [UInt8]) async throws -> UnifiedSpendingKey { + if let error = createAccountSeedThrowableError { + throw error + } + createAccountSeedCallsCount += 1 + createAccountSeedReceivedSeed = seed + if let closure = createAccountSeedClosure { + return try await closure(seed) + } else { + return createAccountSeedReturnValue + } + } + + // MARK: - createToAddress + + var createToAddressUskToValueMemoThrowableError: Error? + func setCreateToAddressUskToValueMemoThrowableError(_ param: Error?) async { + createToAddressUskToValueMemoThrowableError = param + } + var createToAddressUskToValueMemoCallsCount = 0 + var createToAddressUskToValueMemoCalled: Bool { + return createToAddressUskToValueMemoCallsCount > 0 + } + var createToAddressUskToValueMemoReceivedArguments: (usk: UnifiedSpendingKey, address: String, value: Int64, memo: MemoBytes?)? + var createToAddressUskToValueMemoReturnValue: Int64! + func setCreateToAddressUskToValueMemoReturnValue(_ param: Int64) async { + createToAddressUskToValueMemoReturnValue = param + } + var createToAddressUskToValueMemoClosure: ((UnifiedSpendingKey, String, Int64, MemoBytes?) async throws -> Int64)? + func setCreateToAddressUskToValueMemoClosure(_ param: ((UnifiedSpendingKey, String, Int64, MemoBytes?) async throws -> Int64)?) async { + createToAddressUskToValueMemoClosure = param + } + + func createToAddress(usk: UnifiedSpendingKey, to address: String, value: Int64, memo: MemoBytes?) async throws -> Int64 { + if let error = createToAddressUskToValueMemoThrowableError { + throw error + } + createToAddressUskToValueMemoCallsCount += 1 + createToAddressUskToValueMemoReceivedArguments = (usk: usk, address: address, value: value, memo: memo) + if let closure = createToAddressUskToValueMemoClosure { + return try await closure(usk, address, value, memo) + } else { + return createToAddressUskToValueMemoReturnValue + } + } + + // MARK: - decryptAndStoreTransaction + + var decryptAndStoreTransactionTxBytesMinedHeightThrowableError: Error? + func setDecryptAndStoreTransactionTxBytesMinedHeightThrowableError(_ param: Error?) async { + decryptAndStoreTransactionTxBytesMinedHeightThrowableError = param + } + var decryptAndStoreTransactionTxBytesMinedHeightCallsCount = 0 + var decryptAndStoreTransactionTxBytesMinedHeightCalled: Bool { + return decryptAndStoreTransactionTxBytesMinedHeightCallsCount > 0 + } + var decryptAndStoreTransactionTxBytesMinedHeightReceivedArguments: (txBytes: [UInt8], minedHeight: Int32)? + var decryptAndStoreTransactionTxBytesMinedHeightClosure: (([UInt8], Int32) async throws -> Void)? + func setDecryptAndStoreTransactionTxBytesMinedHeightClosure(_ param: (([UInt8], Int32) async throws -> Void)?) async { + decryptAndStoreTransactionTxBytesMinedHeightClosure = param + } + + func decryptAndStoreTransaction(txBytes: [UInt8], minedHeight: Int32) async throws { + if let error = decryptAndStoreTransactionTxBytesMinedHeightThrowableError { + throw error + } + decryptAndStoreTransactionTxBytesMinedHeightCallsCount += 1 + decryptAndStoreTransactionTxBytesMinedHeightReceivedArguments = (txBytes: txBytes, minedHeight: minedHeight) + try await decryptAndStoreTransactionTxBytesMinedHeightClosure?(txBytes, minedHeight) + } + + // MARK: - getBalance + + var getBalanceAccountThrowableError: Error? + func setGetBalanceAccountThrowableError(_ param: Error?) async { + getBalanceAccountThrowableError = param + } + var getBalanceAccountCallsCount = 0 + var getBalanceAccountCalled: Bool { + return getBalanceAccountCallsCount > 0 + } + var getBalanceAccountReceivedAccount: Int32? + var getBalanceAccountReturnValue: Int64! + func setGetBalanceAccountReturnValue(_ param: Int64) async { + getBalanceAccountReturnValue = param + } + var getBalanceAccountClosure: ((Int32) async throws -> Int64)? + func setGetBalanceAccountClosure(_ param: ((Int32) async throws -> Int64)?) async { + getBalanceAccountClosure = param + } + + func getBalance(account: Int32) async throws -> Int64 { + if let error = getBalanceAccountThrowableError { + throw error + } + getBalanceAccountCallsCount += 1 + getBalanceAccountReceivedAccount = account + if let closure = getBalanceAccountClosure { + return try await closure(account) + } else { + return getBalanceAccountReturnValue + } + } + + // MARK: - getCurrentAddress + + var getCurrentAddressAccountThrowableError: Error? + func setGetCurrentAddressAccountThrowableError(_ param: Error?) async { + getCurrentAddressAccountThrowableError = param + } + var getCurrentAddressAccountCallsCount = 0 + var getCurrentAddressAccountCalled: Bool { + return getCurrentAddressAccountCallsCount > 0 + } + var getCurrentAddressAccountReceivedAccount: Int32? + var getCurrentAddressAccountReturnValue: UnifiedAddress! + func setGetCurrentAddressAccountReturnValue(_ param: UnifiedAddress) async { + getCurrentAddressAccountReturnValue = param + } + var getCurrentAddressAccountClosure: ((Int32) async throws -> UnifiedAddress)? + func setGetCurrentAddressAccountClosure(_ param: ((Int32) async throws -> UnifiedAddress)?) async { + getCurrentAddressAccountClosure = param + } + + func getCurrentAddress(account: Int32) async throws -> UnifiedAddress { + if let error = getCurrentAddressAccountThrowableError { + throw error + } + getCurrentAddressAccountCallsCount += 1 + getCurrentAddressAccountReceivedAccount = account + if let closure = getCurrentAddressAccountClosure { + return try await closure(account) + } else { + return getCurrentAddressAccountReturnValue + } + } + + // MARK: - getNearestRewindHeight + + var getNearestRewindHeightHeightThrowableError: Error? + func setGetNearestRewindHeightHeightThrowableError(_ param: Error?) async { + getNearestRewindHeightHeightThrowableError = param + } + var getNearestRewindHeightHeightCallsCount = 0 + var getNearestRewindHeightHeightCalled: Bool { + return getNearestRewindHeightHeightCallsCount > 0 + } + var getNearestRewindHeightHeightReceivedHeight: Int32? + var getNearestRewindHeightHeightReturnValue: Int32! + func setGetNearestRewindHeightHeightReturnValue(_ param: Int32) async { + getNearestRewindHeightHeightReturnValue = param + } + var getNearestRewindHeightHeightClosure: ((Int32) async throws -> Int32)? + func setGetNearestRewindHeightHeightClosure(_ param: ((Int32) async throws -> Int32)?) async { + getNearestRewindHeightHeightClosure = param + } + + func getNearestRewindHeight(height: Int32) async throws -> Int32 { + if let error = getNearestRewindHeightHeightThrowableError { + throw error + } + getNearestRewindHeightHeightCallsCount += 1 + getNearestRewindHeightHeightReceivedHeight = height + if let closure = getNearestRewindHeightHeightClosure { + return try await closure(height) + } else { + return getNearestRewindHeightHeightReturnValue + } + } + + // MARK: - getNextAvailableAddress + + var getNextAvailableAddressAccountThrowableError: Error? + func setGetNextAvailableAddressAccountThrowableError(_ param: Error?) async { + getNextAvailableAddressAccountThrowableError = param + } + var getNextAvailableAddressAccountCallsCount = 0 + var getNextAvailableAddressAccountCalled: Bool { + return getNextAvailableAddressAccountCallsCount > 0 + } + var getNextAvailableAddressAccountReceivedAccount: Int32? + var getNextAvailableAddressAccountReturnValue: UnifiedAddress! + func setGetNextAvailableAddressAccountReturnValue(_ param: UnifiedAddress) async { + getNextAvailableAddressAccountReturnValue = param + } + var getNextAvailableAddressAccountClosure: ((Int32) async throws -> UnifiedAddress)? + func setGetNextAvailableAddressAccountClosure(_ param: ((Int32) async throws -> UnifiedAddress)?) async { + getNextAvailableAddressAccountClosure = param + } + + func getNextAvailableAddress(account: Int32) async throws -> UnifiedAddress { + if let error = getNextAvailableAddressAccountThrowableError { + throw error + } + getNextAvailableAddressAccountCallsCount += 1 + getNextAvailableAddressAccountReceivedAccount = account + if let closure = getNextAvailableAddressAccountClosure { + return try await closure(account) + } else { + return getNextAvailableAddressAccountReturnValue + } + } + + // MARK: - getReceivedMemo + + var getReceivedMemoIdNoteCallsCount = 0 + var getReceivedMemoIdNoteCalled: Bool { + return getReceivedMemoIdNoteCallsCount > 0 + } + var getReceivedMemoIdNoteReceivedIdNote: Int64? + var getReceivedMemoIdNoteReturnValue: Memo? + func setGetReceivedMemoIdNoteReturnValue(_ param: Memo?) async { + getReceivedMemoIdNoteReturnValue = param + } + var getReceivedMemoIdNoteClosure: ((Int64) async -> Memo?)? + func setGetReceivedMemoIdNoteClosure(_ param: ((Int64) async -> Memo?)?) async { + getReceivedMemoIdNoteClosure = param + } + + func getReceivedMemo(idNote: Int64) async -> Memo? { + getReceivedMemoIdNoteCallsCount += 1 + getReceivedMemoIdNoteReceivedIdNote = idNote + if let closure = getReceivedMemoIdNoteClosure { + return await closure(idNote) + } else { + return getReceivedMemoIdNoteReturnValue + } + } + + // MARK: - getSentMemo + + var getSentMemoIdNoteCallsCount = 0 + var getSentMemoIdNoteCalled: Bool { + return getSentMemoIdNoteCallsCount > 0 + } + var getSentMemoIdNoteReceivedIdNote: Int64? + var getSentMemoIdNoteReturnValue: Memo? + func setGetSentMemoIdNoteReturnValue(_ param: Memo?) async { + getSentMemoIdNoteReturnValue = param + } + var getSentMemoIdNoteClosure: ((Int64) async -> Memo?)? + func setGetSentMemoIdNoteClosure(_ param: ((Int64) async -> Memo?)?) async { + getSentMemoIdNoteClosure = param + } + + func getSentMemo(idNote: Int64) async -> Memo? { + getSentMemoIdNoteCallsCount += 1 + getSentMemoIdNoteReceivedIdNote = idNote + if let closure = getSentMemoIdNoteClosure { + return await closure(idNote) + } else { + return getSentMemoIdNoteReturnValue + } + } + + // MARK: - getTransparentBalance + + var getTransparentBalanceAccountThrowableError: Error? + func setGetTransparentBalanceAccountThrowableError(_ param: Error?) async { + getTransparentBalanceAccountThrowableError = param + } + var getTransparentBalanceAccountCallsCount = 0 + var getTransparentBalanceAccountCalled: Bool { + return getTransparentBalanceAccountCallsCount > 0 + } + var getTransparentBalanceAccountReceivedAccount: Int32? + var getTransparentBalanceAccountReturnValue: Int64! + func setGetTransparentBalanceAccountReturnValue(_ param: Int64) async { + getTransparentBalanceAccountReturnValue = param + } + var getTransparentBalanceAccountClosure: ((Int32) async throws -> Int64)? + func setGetTransparentBalanceAccountClosure(_ param: ((Int32) async throws -> Int64)?) async { + getTransparentBalanceAccountClosure = param + } + + func getTransparentBalance(account: Int32) async throws -> Int64 { + if let error = getTransparentBalanceAccountThrowableError { + throw error + } + getTransparentBalanceAccountCallsCount += 1 + getTransparentBalanceAccountReceivedAccount = account + if let closure = getTransparentBalanceAccountClosure { + return try await closure(account) + } else { + return getTransparentBalanceAccountReturnValue + } + } + + // MARK: - initAccountsTable + + var initAccountsTableUfvksThrowableError: Error? + func setInitAccountsTableUfvksThrowableError(_ param: Error?) async { + initAccountsTableUfvksThrowableError = param + } + var initAccountsTableUfvksCallsCount = 0 + var initAccountsTableUfvksCalled: Bool { + return initAccountsTableUfvksCallsCount > 0 + } + var initAccountsTableUfvksReceivedUfvks: [UnifiedFullViewingKey]? + var initAccountsTableUfvksClosure: (([UnifiedFullViewingKey]) async throws -> Void)? + func setInitAccountsTableUfvksClosure(_ param: (([UnifiedFullViewingKey]) async throws -> Void)?) async { + initAccountsTableUfvksClosure = param + } + + func initAccountsTable(ufvks: [UnifiedFullViewingKey]) async throws { + if let error = initAccountsTableUfvksThrowableError { + throw error + } + initAccountsTableUfvksCallsCount += 1 + initAccountsTableUfvksReceivedUfvks = ufvks + try await initAccountsTableUfvksClosure?(ufvks) + } + + // MARK: - initDataDb + + var initDataDbSeedThrowableError: Error? + func setInitDataDbSeedThrowableError(_ param: Error?) async { + initDataDbSeedThrowableError = param + } + var initDataDbSeedCallsCount = 0 + var initDataDbSeedCalled: Bool { + return initDataDbSeedCallsCount > 0 + } + var initDataDbSeedReceivedSeed: [UInt8]? + var initDataDbSeedReturnValue: DbInitResult! + func setInitDataDbSeedReturnValue(_ param: DbInitResult) async { + initDataDbSeedReturnValue = param + } + var initDataDbSeedClosure: (([UInt8]?) async throws -> DbInitResult)? + func setInitDataDbSeedClosure(_ param: (([UInt8]?) async throws -> DbInitResult)?) async { + initDataDbSeedClosure = param + } + + func initDataDb(seed: [UInt8]?) async throws -> DbInitResult { + if let error = initDataDbSeedThrowableError { + throw error + } + initDataDbSeedCallsCount += 1 + initDataDbSeedReceivedSeed = seed + if let closure = initDataDbSeedClosure { + return try await closure(seed) + } else { + return initDataDbSeedReturnValue + } + } + + // MARK: - initBlocksTable + + var initBlocksTableHeightHashTimeSaplingTreeThrowableError: Error? + func setInitBlocksTableHeightHashTimeSaplingTreeThrowableError(_ param: Error?) async { + initBlocksTableHeightHashTimeSaplingTreeThrowableError = param + } + var initBlocksTableHeightHashTimeSaplingTreeCallsCount = 0 + var initBlocksTableHeightHashTimeSaplingTreeCalled: Bool { + return initBlocksTableHeightHashTimeSaplingTreeCallsCount > 0 + } + var initBlocksTableHeightHashTimeSaplingTreeReceivedArguments: (height: Int32, hash: String, time: UInt32, saplingTree: String)? + var initBlocksTableHeightHashTimeSaplingTreeClosure: ((Int32, String, UInt32, String) async throws -> Void)? + func setInitBlocksTableHeightHashTimeSaplingTreeClosure(_ param: ((Int32, String, UInt32, String) async throws -> Void)?) async { + initBlocksTableHeightHashTimeSaplingTreeClosure = param + } + + func initBlocksTable(height: Int32, hash: String, time: UInt32, saplingTree: String) async throws { + if let error = initBlocksTableHeightHashTimeSaplingTreeThrowableError { + throw error + } + initBlocksTableHeightHashTimeSaplingTreeCallsCount += 1 + initBlocksTableHeightHashTimeSaplingTreeReceivedArguments = (height: height, hash: hash, time: time, saplingTree: saplingTree) + try await initBlocksTableHeightHashTimeSaplingTreeClosure?(height, hash, time, saplingTree) + } + + // MARK: - listTransparentReceivers + + var listTransparentReceiversAccountThrowableError: Error? + func setListTransparentReceiversAccountThrowableError(_ param: Error?) async { + listTransparentReceiversAccountThrowableError = param + } + var listTransparentReceiversAccountCallsCount = 0 + var listTransparentReceiversAccountCalled: Bool { + return listTransparentReceiversAccountCallsCount > 0 + } + var listTransparentReceiversAccountReceivedAccount: Int32? + var listTransparentReceiversAccountReturnValue: [TransparentAddress]! + func setListTransparentReceiversAccountReturnValue(_ param: [TransparentAddress]) async { + listTransparentReceiversAccountReturnValue = param + } + var listTransparentReceiversAccountClosure: ((Int32) async throws -> [TransparentAddress])? + func setListTransparentReceiversAccountClosure(_ param: ((Int32) async throws -> [TransparentAddress])?) async { + listTransparentReceiversAccountClosure = param + } + + func listTransparentReceivers(account: Int32) async throws -> [TransparentAddress] { + if let error = listTransparentReceiversAccountThrowableError { + throw error + } + listTransparentReceiversAccountCallsCount += 1 + listTransparentReceiversAccountReceivedAccount = account + if let closure = listTransparentReceiversAccountClosure { + return try await closure(account) + } else { + return listTransparentReceiversAccountReturnValue + } + } + + // MARK: - getVerifiedBalance + + var getVerifiedBalanceAccountThrowableError: Error? + func setGetVerifiedBalanceAccountThrowableError(_ param: Error?) async { + getVerifiedBalanceAccountThrowableError = param + } + var getVerifiedBalanceAccountCallsCount = 0 + var getVerifiedBalanceAccountCalled: Bool { + return getVerifiedBalanceAccountCallsCount > 0 + } + var getVerifiedBalanceAccountReceivedAccount: Int32? + var getVerifiedBalanceAccountReturnValue: Int64! + func setGetVerifiedBalanceAccountReturnValue(_ param: Int64) async { + getVerifiedBalanceAccountReturnValue = param + } + var getVerifiedBalanceAccountClosure: ((Int32) async throws -> Int64)? + func setGetVerifiedBalanceAccountClosure(_ param: ((Int32) async throws -> Int64)?) async { + getVerifiedBalanceAccountClosure = param + } + + func getVerifiedBalance(account: Int32) async throws -> Int64 { + if let error = getVerifiedBalanceAccountThrowableError { + throw error + } + getVerifiedBalanceAccountCallsCount += 1 + getVerifiedBalanceAccountReceivedAccount = account + if let closure = getVerifiedBalanceAccountClosure { + return try await closure(account) + } else { + return getVerifiedBalanceAccountReturnValue + } + } + + // MARK: - getVerifiedTransparentBalance + + var getVerifiedTransparentBalanceAccountThrowableError: Error? + func setGetVerifiedTransparentBalanceAccountThrowableError(_ param: Error?) async { + getVerifiedTransparentBalanceAccountThrowableError = param + } + var getVerifiedTransparentBalanceAccountCallsCount = 0 + var getVerifiedTransparentBalanceAccountCalled: Bool { + return getVerifiedTransparentBalanceAccountCallsCount > 0 + } + var getVerifiedTransparentBalanceAccountReceivedAccount: Int32? + var getVerifiedTransparentBalanceAccountReturnValue: Int64! + func setGetVerifiedTransparentBalanceAccountReturnValue(_ param: Int64) async { + getVerifiedTransparentBalanceAccountReturnValue = param + } + var getVerifiedTransparentBalanceAccountClosure: ((Int32) async throws -> Int64)? + func setGetVerifiedTransparentBalanceAccountClosure(_ param: ((Int32) async throws -> Int64)?) async { + getVerifiedTransparentBalanceAccountClosure = param + } + + func getVerifiedTransparentBalance(account: Int32) async throws -> Int64 { + if let error = getVerifiedTransparentBalanceAccountThrowableError { + throw error + } + getVerifiedTransparentBalanceAccountCallsCount += 1 + getVerifiedTransparentBalanceAccountReceivedAccount = account + if let closure = getVerifiedTransparentBalanceAccountClosure { + return try await closure(account) + } else { + return getVerifiedTransparentBalanceAccountReturnValue + } + } + + // MARK: - validateCombinedChain + + var validateCombinedChainLimitThrowableError: Error? + func setValidateCombinedChainLimitThrowableError(_ param: Error?) async { + validateCombinedChainLimitThrowableError = param + } + var validateCombinedChainLimitCallsCount = 0 + var validateCombinedChainLimitCalled: Bool { + return validateCombinedChainLimitCallsCount > 0 + } + var validateCombinedChainLimitReceivedLimit: UInt32? + var validateCombinedChainLimitClosure: ((UInt32) async throws -> Void)? + func setValidateCombinedChainLimitClosure(_ param: ((UInt32) async throws -> Void)?) async { + validateCombinedChainLimitClosure = param + } + + func validateCombinedChain(limit: UInt32) async throws { + if let error = validateCombinedChainLimitThrowableError { + throw error + } + validateCombinedChainLimitCallsCount += 1 + validateCombinedChainLimitReceivedLimit = limit + try await validateCombinedChainLimitClosure?(limit) + } + + // MARK: - rewindToHeight + + var rewindToHeightHeightThrowableError: Error? + func setRewindToHeightHeightThrowableError(_ param: Error?) async { + rewindToHeightHeightThrowableError = param + } + var rewindToHeightHeightCallsCount = 0 + var rewindToHeightHeightCalled: Bool { + return rewindToHeightHeightCallsCount > 0 + } + var rewindToHeightHeightReceivedHeight: Int32? + var rewindToHeightHeightClosure: ((Int32) async throws -> Void)? + func setRewindToHeightHeightClosure(_ param: ((Int32) async throws -> Void)?) async { + rewindToHeightHeightClosure = param + } + + func rewindToHeight(height: Int32) async throws { + if let error = rewindToHeightHeightThrowableError { + throw error + } + rewindToHeightHeightCallsCount += 1 + rewindToHeightHeightReceivedHeight = height + try await rewindToHeightHeightClosure?(height) + } + + // MARK: - rewindCacheToHeight + + var rewindCacheToHeightHeightThrowableError: Error? + func setRewindCacheToHeightHeightThrowableError(_ param: Error?) async { + rewindCacheToHeightHeightThrowableError = param + } + var rewindCacheToHeightHeightCallsCount = 0 + var rewindCacheToHeightHeightCalled: Bool { + return rewindCacheToHeightHeightCallsCount > 0 + } + var rewindCacheToHeightHeightReceivedHeight: Int32? + var rewindCacheToHeightHeightClosure: ((Int32) async throws -> Void)? + func setRewindCacheToHeightHeightClosure(_ param: ((Int32) async throws -> Void)?) async { + rewindCacheToHeightHeightClosure = param + } + + func rewindCacheToHeight(height: Int32) async throws { + if let error = rewindCacheToHeightHeightThrowableError { + throw error + } + rewindCacheToHeightHeightCallsCount += 1 + rewindCacheToHeightHeightReceivedHeight = height + try await rewindCacheToHeightHeightClosure?(height) + } + + // MARK: - scanBlocks + + var scanBlocksLimitThrowableError: Error? + func setScanBlocksLimitThrowableError(_ param: Error?) async { + scanBlocksLimitThrowableError = param + } + var scanBlocksLimitCallsCount = 0 + var scanBlocksLimitCalled: Bool { + return scanBlocksLimitCallsCount > 0 + } + var scanBlocksLimitReceivedLimit: UInt32? + var scanBlocksLimitClosure: ((UInt32) async throws -> Void)? + func setScanBlocksLimitClosure(_ param: ((UInt32) async throws -> Void)?) async { + scanBlocksLimitClosure = param + } + + func scanBlocks(limit: UInt32) async throws { + if let error = scanBlocksLimitThrowableError { + throw error + } + scanBlocksLimitCallsCount += 1 + scanBlocksLimitReceivedLimit = limit + try await scanBlocksLimitClosure?(limit) + } + + // MARK: - putUnspentTransparentOutput + + var putUnspentTransparentOutputTxidIndexScriptValueHeightThrowableError: Error? + func setPutUnspentTransparentOutputTxidIndexScriptValueHeightThrowableError(_ param: Error?) async { + putUnspentTransparentOutputTxidIndexScriptValueHeightThrowableError = param + } + var putUnspentTransparentOutputTxidIndexScriptValueHeightCallsCount = 0 + var putUnspentTransparentOutputTxidIndexScriptValueHeightCalled: Bool { + return putUnspentTransparentOutputTxidIndexScriptValueHeightCallsCount > 0 + } + var putUnspentTransparentOutputTxidIndexScriptValueHeightReceivedArguments: (txid: [UInt8], index: Int, script: [UInt8], value: Int64, height: BlockHeight)? + var putUnspentTransparentOutputTxidIndexScriptValueHeightClosure: (([UInt8], Int, [UInt8], Int64, BlockHeight) async throws -> Void)? + func setPutUnspentTransparentOutputTxidIndexScriptValueHeightClosure(_ param: (([UInt8], Int, [UInt8], Int64, BlockHeight) async throws -> Void)?) async { + putUnspentTransparentOutputTxidIndexScriptValueHeightClosure = param + } + + func putUnspentTransparentOutput(txid: [UInt8], index: Int, script: [UInt8], value: Int64, height: BlockHeight) async throws { + if let error = putUnspentTransparentOutputTxidIndexScriptValueHeightThrowableError { + throw error + } + putUnspentTransparentOutputTxidIndexScriptValueHeightCallsCount += 1 + putUnspentTransparentOutputTxidIndexScriptValueHeightReceivedArguments = (txid: txid, index: index, script: script, value: value, height: height) + try await putUnspentTransparentOutputTxidIndexScriptValueHeightClosure?(txid, index, script, value, height) + } + + // MARK: - shieldFunds + + var shieldFundsUskMemoShieldingThresholdThrowableError: Error? + func setShieldFundsUskMemoShieldingThresholdThrowableError(_ param: Error?) async { + shieldFundsUskMemoShieldingThresholdThrowableError = param + } + var shieldFundsUskMemoShieldingThresholdCallsCount = 0 + var shieldFundsUskMemoShieldingThresholdCalled: Bool { + return shieldFundsUskMemoShieldingThresholdCallsCount > 0 + } + var shieldFundsUskMemoShieldingThresholdReceivedArguments: (usk: UnifiedSpendingKey, memo: MemoBytes?, shieldingThreshold: Zatoshi)? + var shieldFundsUskMemoShieldingThresholdReturnValue: Int64! + func setShieldFundsUskMemoShieldingThresholdReturnValue(_ param: Int64) async { + shieldFundsUskMemoShieldingThresholdReturnValue = param + } + var shieldFundsUskMemoShieldingThresholdClosure: ((UnifiedSpendingKey, MemoBytes?, Zatoshi) async throws -> Int64)? + func setShieldFundsUskMemoShieldingThresholdClosure(_ param: ((UnifiedSpendingKey, MemoBytes?, Zatoshi) async throws -> Int64)?) async { + shieldFundsUskMemoShieldingThresholdClosure = param + } + + func shieldFunds(usk: UnifiedSpendingKey, memo: MemoBytes?, shieldingThreshold: Zatoshi) async throws -> Int64 { + if let error = shieldFundsUskMemoShieldingThresholdThrowableError { + throw error + } + shieldFundsUskMemoShieldingThresholdCallsCount += 1 + shieldFundsUskMemoShieldingThresholdReceivedArguments = (usk: usk, memo: memo, shieldingThreshold: shieldingThreshold) + if let closure = shieldFundsUskMemoShieldingThresholdClosure { + return try await closure(usk, memo, shieldingThreshold) + } else { + return shieldFundsUskMemoShieldingThresholdReturnValue + } + } + + // MARK: - consensusBranchIdFor + + + nonisolated func consensusBranchIdFor(height: Int32) throws -> Int32 { + try consensusBranchIdForHeightClosure!(height) + } + + // MARK: - initBlockMetadataDb + + var initBlockMetadataDbThrowableError: Error? + func setInitBlockMetadataDbThrowableError(_ param: Error?) async { + initBlockMetadataDbThrowableError = param + } + var initBlockMetadataDbCallsCount = 0 + var initBlockMetadataDbCalled: Bool { + return initBlockMetadataDbCallsCount > 0 + } + var initBlockMetadataDbClosure: (() async throws -> Void)? + func setInitBlockMetadataDbClosure(_ param: (() async throws -> Void)?) async { + initBlockMetadataDbClosure = param + } + + func initBlockMetadataDb() async throws { + if let error = initBlockMetadataDbThrowableError { + throw error + } + initBlockMetadataDbCallsCount += 1 + try await initBlockMetadataDbClosure?() + } + + // MARK: - writeBlocksMetadata + + var writeBlocksMetadataBlocksThrowableError: Error? + func setWriteBlocksMetadataBlocksThrowableError(_ param: Error?) async { + writeBlocksMetadataBlocksThrowableError = param + } + var writeBlocksMetadataBlocksCallsCount = 0 + var writeBlocksMetadataBlocksCalled: Bool { + return writeBlocksMetadataBlocksCallsCount > 0 + } + var writeBlocksMetadataBlocksReceivedBlocks: [ZcashCompactBlock]? + var writeBlocksMetadataBlocksClosure: (([ZcashCompactBlock]) async throws -> Void)? + func setWriteBlocksMetadataBlocksClosure(_ param: (([ZcashCompactBlock]) async throws -> Void)?) async { + writeBlocksMetadataBlocksClosure = param + } + + func writeBlocksMetadata(blocks: [ZcashCompactBlock]) async throws { + if let error = writeBlocksMetadataBlocksThrowableError { + throw error + } + writeBlocksMetadataBlocksCallsCount += 1 + writeBlocksMetadataBlocksReceivedBlocks = blocks + try await writeBlocksMetadataBlocksClosure?(blocks) + } + + // MARK: - latestCachedBlockHeight + + var latestCachedBlockHeightCallsCount = 0 + var latestCachedBlockHeightCalled: Bool { + return latestCachedBlockHeightCallsCount > 0 + } + var latestCachedBlockHeightReturnValue: BlockHeight! + func setLatestCachedBlockHeightReturnValue(_ param: BlockHeight) async { + latestCachedBlockHeightReturnValue = param + } + var latestCachedBlockHeightClosure: (() async -> BlockHeight)? + func setLatestCachedBlockHeightClosure(_ param: (() async -> BlockHeight)?) async { + latestCachedBlockHeightClosure = param + } + + func latestCachedBlockHeight() async -> BlockHeight { + latestCachedBlockHeightCallsCount += 1 + if let closure = latestCachedBlockHeightClosure { + return await closure() + } else { + return latestCachedBlockHeightReturnValue + } + } + +} diff --git a/Tests/TestUtils/Sourcery/generateMocks.sh b/Tests/TestUtils/Sourcery/generateMocks.sh new file mode 100755 index 00000000..16bde030 --- /dev/null +++ b/Tests/TestUtils/Sourcery/generateMocks.sh @@ -0,0 +1,22 @@ +#!/bin/zsh + +sourcery_version=2.0.2 + +if which sourcery >/dev/null; then + if [[ $(sourcery --version) != $sourcery_version ]]; then + echo "warning: Compatible sourcery version not installed. Install sourcer $sourcery_version. Currently installed version is $(sourcery --version)" + exit 1 + fi + + sourcery \ + --sources ./ \ + --sources ../../../Sources/ \ + --templates AutoMockable.stencil \ + --output GeneratedMocks/ + +else + echo "warning: sourcery not installed" +fi + + + diff --git a/Tests/TestUtils/Stubs.swift b/Tests/TestUtils/Stubs.swift index 14be95d5..3b32e519 100644 --- a/Tests/TestUtils/Stubs.swift +++ b/Tests/TestUtils/Stubs.swift @@ -49,342 +49,128 @@ extension LightWalletServiceMockResponse { } } -class MockRustBackend: ZcashRustBackendWelding { - static var networkType = NetworkType.testnet - static var mockDataDb = false - static var mockAcounts = false - static var mockError: RustWeldingError? - static var mockLastError: String? - static var mockAccounts: [SaplingExtendedSpendingKey]? - static var mockAddresses: [String]? - static var mockBalance: Int64? - static var mockVerifiedBalance: Int64? - static var mockMemo: String? - static var mockSentMemo: String? - static var mockValidateCombinedChainSuccessRate: Float? - static var mockValidateCombinedChainFailAfterAttempts: Int? - static var mockValidateCombinedChainKeepFailing = false - static var mockValidateCombinedChainFailureHeight: BlockHeight = 0 - static var mockScanblocksSuccessRate: Float? - static var mockCreateToAddress: Int64? - static var rustBackend = ZcashRustBackend.self - static var consensusBranchID: Int32? - static var writeBlocksMetadataResult: () throws -> Bool = { true } - static var rewindCacheToHeightResult: () -> Bool = { true } - static func latestCachedBlockHeight(fsBlockDbRoot: URL) async -> ZcashLightClientKit.BlockHeight { - .empty() +class RustBackendMockHelper { + let rustBackendMock: ZcashRustBackendWeldingMock + var mockValidateCombinedChainFailAfterAttempts: Int? + + init( + rustBackend: ZcashRustBackendWelding, + consensusBranchID: Int32? = nil, + mockValidateCombinedChainSuccessRate: Float? = nil, + mockValidateCombinedChainFailAfterAttempts: Int? = nil, + mockValidateCombinedChainKeepFailing: Bool = false, + mockValidateCombinedChainFailureError: RustWeldingError = .chainValidationFailed(message: nil) + ) async { + self.mockValidateCombinedChainFailAfterAttempts = mockValidateCombinedChainFailAfterAttempts + self.rustBackendMock = ZcashRustBackendWeldingMock( + consensusBranchIdForHeightClosure: { height in + if let consensusBranchID { + return consensusBranchID + } else { + return try rustBackend.consensusBranchIdFor(height: height) + } + } + ) + await setupDefaultMock( + rustBackend: rustBackend, + mockValidateCombinedChainSuccessRate: mockValidateCombinedChainSuccessRate, + mockValidateCombinedChainKeepFailing: mockValidateCombinedChainKeepFailing, + mockValidateCombinedChainFailureError: mockValidateCombinedChainFailureError + ) } - static func rewindCacheToHeight(fsBlockDbRoot: URL, height: Int32) async -> Bool { - rewindCacheToHeightResult() - } + private func setupDefaultMock( + rustBackend: ZcashRustBackendWelding, + mockValidateCombinedChainSuccessRate: Float? = nil, + mockValidateCombinedChainKeepFailing: Bool = false, + mockValidateCombinedChainFailureError: RustWeldingError = .chainValidationFailed(message: nil) + ) async { + await rustBackendMock.setLatestCachedBlockHeightReturnValue(.empty()) + await rustBackendMock.setInitBlockMetadataDbClosure() { } + await rustBackendMock.setWriteBlocksMetadataBlocksClosure() { _ in } + await rustBackendMock.setInitAccountsTableUfvksClosure() { _ in } + await rustBackendMock.setCreateToAddressUskToValueMemoReturnValue(-1) + await rustBackendMock.setShieldFundsUskMemoShieldingThresholdReturnValue(-1) + await rustBackendMock.setGetTransparentBalanceAccountReturnValue(0) + await rustBackendMock.setGetVerifiedBalanceAccountReturnValue(0) + await rustBackendMock.setListTransparentReceiversAccountReturnValue([]) + await rustBackendMock.setGetCurrentAddressAccountThrowableError(KeyDerivationErrors.unableToDerive) + await rustBackendMock.setGetNextAvailableAddressAccountThrowableError(KeyDerivationErrors.unableToDerive) + await rustBackendMock.setShieldFundsUskMemoShieldingThresholdReturnValue(-1) + await rustBackendMock.setCreateAccountSeedThrowableError(KeyDerivationErrors.unableToDerive) + await rustBackendMock.setGetReceivedMemoIdNoteReturnValue(nil) + await rustBackendMock.setGetSentMemoIdNoteReturnValue(nil) + await rustBackendMock.setCreateToAddressUskToValueMemoReturnValue(-1) + await rustBackendMock.setInitDataDbSeedReturnValue(.seedRequired) + await rustBackendMock.setGetNearestRewindHeightHeightReturnValue(-1) + await rustBackendMock.setInitBlocksTableHeightHashTimeSaplingTreeClosure() { _, _, _, _ in } + await rustBackendMock.setPutUnspentTransparentOutputTxidIndexScriptValueHeightClosure() { _, _, _, _, _ in } + await rustBackendMock.setCreateToAddressUskToValueMemoReturnValue(-1) + await rustBackendMock.setCreateToAddressUskToValueMemoReturnValue(-1) + await rustBackendMock.setDecryptAndStoreTransactionTxBytesMinedHeightThrowableError(RustWeldingError.genericError(message: "mock fail")) - static func initBlockMetadataDb(fsBlockDbRoot: URL) async throws -> Bool { - true - } - - static func writeBlocksMetadata(fsBlockDbRoot: URL, blocks: [ZcashLightClientKit.ZcashCompactBlock]) async throws -> Bool { - try writeBlocksMetadataResult() - } - - static func initAccountsTable(dbData: URL, ufvks: [ZcashLightClientKit.UnifiedFullViewingKey], networkType: ZcashLightClientKit.NetworkType) async throws { } - - static func createToAddress(dbData: URL, usk: ZcashLightClientKit.UnifiedSpendingKey, to address: String, value: Int64, memo: ZcashLightClientKit.MemoBytes?, spendParamsPath: String, outputParamsPath: String, networkType: ZcashLightClientKit.NetworkType) async -> Int64 { - -1 - } - - static func shieldFunds( - dbData: URL, - usk: ZcashLightClientKit.UnifiedSpendingKey, - memo: ZcashLightClientKit.MemoBytes?, - shieldingThreshold: Zatoshi, - spendParamsPath: String, - outputParamsPath: String, - networkType: ZcashLightClientKit.NetworkType - ) async -> Int64 { - -1 - } - - static func getAddressMetadata(_ address: String) -> ZcashLightClientKit.AddressMetadata? { - nil - } - - static func clearUtxos(dbData: URL, address: ZcashLightClientKit.TransparentAddress, sinceHeight: ZcashLightClientKit.BlockHeight, networkType: ZcashLightClientKit.NetworkType) async throws -> Int32 { - 0 - } - - static func getTransparentBalance(dbData: URL, account: Int32, networkType: ZcashLightClientKit.NetworkType) async throws -> Int64 { 0 } - - static func getVerifiedTransparentBalance(dbData: URL, account: Int32, networkType: ZcashLightClientKit.NetworkType) async throws -> Int64 { 0 } - - static func listTransparentReceivers(dbData: URL, account: Int32, networkType: ZcashLightClientKit.NetworkType) async throws -> [ZcashLightClientKit.TransparentAddress] { - [] - } - - static func deriveUnifiedFullViewingKey(from spendingKey: ZcashLightClientKit.UnifiedSpendingKey, networkType: ZcashLightClientKit.NetworkType) throws -> ZcashLightClientKit.UnifiedFullViewingKey { - throw KeyDerivationErrors.unableToDerive - } - - static func deriveUnifiedSpendingKey(from seed: [UInt8], accountIndex: Int32, networkType: ZcashLightClientKit.NetworkType) throws -> ZcashLightClientKit.UnifiedSpendingKey { - throw KeyDerivationErrors.unableToDerive - } - - static func getCurrentAddress(dbData: URL, account: Int32, networkType: ZcashLightClientKit.NetworkType) async throws -> ZcashLightClientKit.UnifiedAddress { - throw KeyDerivationErrors.unableToDerive - } - - static func getNextAvailableAddress(dbData: URL, account: Int32, networkType: ZcashLightClientKit.NetworkType) async throws -> ZcashLightClientKit.UnifiedAddress { - throw KeyDerivationErrors.unableToDerive - } - - static func getSaplingReceiver(for uAddr: ZcashLightClientKit.UnifiedAddress) throws -> ZcashLightClientKit.SaplingAddress { - throw KeyDerivationErrors.unableToDerive - } - - static func getTransparentReceiver(for uAddr: ZcashLightClientKit.UnifiedAddress) throws -> ZcashLightClientKit.TransparentAddress { - throw KeyDerivationErrors.unableToDerive - } - - static func shieldFunds(dbData: URL, usk: ZcashLightClientKit.UnifiedSpendingKey, memo: ZcashLightClientKit.MemoBytes, spendParamsPath: String, outputParamsPath: String, networkType: ZcashLightClientKit.NetworkType) async -> Int64 { - -1 - } - - static func receiverTypecodesOnUnifiedAddress(_ address: String) throws -> [UInt32] { - throw KeyDerivationErrors.receiverNotFound - } - - static func createAccount(dbData: URL, seed: [UInt8], networkType: ZcashLightClientKit.NetworkType) async throws -> ZcashLightClientKit.UnifiedSpendingKey { - throw KeyDerivationErrors.unableToDerive - } - - static func getReceivedMemo(dbData: URL, idNote: Int64, networkType: ZcashLightClientKit.NetworkType) async -> ZcashLightClientKit.Memo? { nil } - - static func getSentMemo(dbData: URL, idNote: Int64, networkType: ZcashLightClientKit.NetworkType) async -> ZcashLightClientKit.Memo? { nil } - - static func createToAddress(dbData: URL, usk: ZcashLightClientKit.UnifiedSpendingKey, to address: String, value: Int64, memo: ZcashLightClientKit.MemoBytes, spendParamsPath: String, outputParamsPath: String, networkType: ZcashLightClientKit.NetworkType) async -> Int64 { - -1 - } - - static func initDataDb(dbData: URL, seed: [UInt8]?, networkType: ZcashLightClientKit.NetworkType) async throws -> ZcashLightClientKit.DbInitResult { - .seedRequired - } - - static func deriveSaplingAddressFromViewingKey(_ extfvk: ZcashLightClientKit.SaplingExtendedFullViewingKey, networkType: ZcashLightClientKit.NetworkType) throws -> ZcashLightClientKit.SaplingAddress { - throw RustWeldingError.unableToDeriveKeys - } - - static func isValidSaplingExtendedSpendingKey(_ key: String, networkType: ZcashLightClientKit.NetworkType) -> Bool { false } - - static func deriveSaplingExtendedFullViewingKeys(seed: [UInt8], accounts: Int32, networkType: ZcashLightClientKit.NetworkType) throws -> [ZcashLightClientKit.SaplingExtendedFullViewingKey]? { - nil - } - - static func isValidUnifiedAddress(_ address: String, networkType: ZcashLightClientKit.NetworkType) -> Bool { - false - } - - static func deriveSaplingExtendedFullViewingKey(_ spendingKey: SaplingExtendedSpendingKey, networkType: ZcashLightClientKit.NetworkType) throws -> ZcashLightClientKit.SaplingExtendedFullViewingKey? { - nil - } - - public func deriveViewingKeys(seed: [UInt8], numberOfAccounts: Int) throws -> [UnifiedFullViewingKey] { [] } - - static func getNearestRewindHeight(dbData: URL, height: Int32, networkType: NetworkType) async -> Int32 { -1 } - - static func network(dbData: URL, address: String, sinceHeight: BlockHeight, networkType: NetworkType) async throws -> Int32 { -1 } - - static func initAccountsTable(dbData: URL, ufvks: [UnifiedFullViewingKey], networkType: NetworkType) async throws -> Bool { false } - - static func putUnspentTransparentOutput( - dbData: URL, - txid: [UInt8], - index: Int, - script: [UInt8], - value: Int64, - height: BlockHeight, - networkType: NetworkType - ) async throws -> Bool { - false - } - - static func downloadedUtxoBalance(dbData: URL, address: String, networkType: NetworkType) async throws -> WalletBalance { - throw RustWeldingError.genericError(message: "unimplemented") - } - - static func createToAddress( - dbData: URL, - account: Int32, - extsk: String, - to address: String, - value: Int64, - memo: String?, - spendParamsPath: String, - outputParamsPath: String, - networkType: NetworkType - ) async -> Int64 { - -1 - } - - static func deriveTransparentAddressFromSeed(seed: [UInt8], account: Int, index: Int, networkType: NetworkType) throws -> TransparentAddress { - throw KeyDerivationErrors.unableToDerive - } - - static func deriveUnifiedFullViewingKeyFromSeed(_ seed: [UInt8], numberOfAccounts: Int32, networkType: NetworkType) throws -> [UnifiedFullViewingKey] { - throw KeyDerivationErrors.unableToDerive - } - - static func isValidSaplingExtendedFullViewingKey(_ key: String, networkType: NetworkType) -> Bool { false } - - static func isValidUnifiedFullViewingKey(_ ufvk: String, networkType: NetworkType) -> Bool { false } - - static func deriveSaplingExtendedSpendingKeys(seed: [UInt8], accounts: Int32, networkType: NetworkType) throws -> [SaplingExtendedSpendingKey]? { nil } - - static func consensusBranchIdFor(height: Int32, networkType: NetworkType) throws -> Int32 { - guard let consensus = consensusBranchID else { - return try rustBackend.consensusBranchIdFor(height: height, networkType: networkType) + await rustBackendMock.setInitDataDbSeedClosure() { seed in + return try await rustBackend.initDataDb(seed: seed) } - return consensus - } - static func lastError() -> RustWeldingError? { - mockError ?? rustBackend.lastError() - } - - static func getLastError() -> String? { - mockLastError ?? rustBackend.getLastError() - } - - static func isValidSaplingAddress(_ address: String, networkType: NetworkType) -> Bool { - true - } - - static func isValidTransparentAddress(_ address: String, networkType: NetworkType) -> Bool { - true - } - - static func initDataDb(dbData: URL, networkType: NetworkType) async throws { - if !mockDataDb { - _ = try await rustBackend.initDataDb(dbData: dbData, seed: nil, networkType: networkType) - } - } - - static func initBlocksTable( - dbData: URL, - height: Int32, - hash: String, - time: UInt32, - saplingTree: String, - networkType: NetworkType - ) async throws { - if !mockDataDb { + await rustBackendMock.setInitBlocksTableHeightHashTimeSaplingTreeClosure() { height, hash, time, saplingTree in try await rustBackend.initBlocksTable( - dbData: dbData, height: height, hash: hash, time: time, - saplingTree: saplingTree, - networkType: networkType + saplingTree: saplingTree ) } - } - - static func getBalance(dbData: URL, account: Int32, networkType: NetworkType) async throws -> Int64 { - if let balance = mockBalance { - return balance - } - return try await rustBackend.getBalance(dbData: dbData, account: account, networkType: networkType) - } - - static func getVerifiedBalance(dbData: URL, account: Int32, networkType: NetworkType) async throws -> Int64 { - if let balance = mockVerifiedBalance { - return balance + + await rustBackendMock.setGetBalanceAccountClosure() { account in + return try await rustBackend.getBalance(account: account) } - return try await rustBackend.getVerifiedBalance(dbData: dbData, account: account, networkType: networkType) - } - - static func validateCombinedChain(fsBlockDbRoot: URL, dbData: URL, networkType: NetworkType, limit: UInt32 = 0) async -> Int32 { - if let rate = self.mockValidateCombinedChainSuccessRate { - if shouldSucceed(successRate: rate) { - return await validationResult(fsBlockDbRoot: fsBlockDbRoot, dbData: dbData, networkType: networkType) - } else { - return Int32(mockValidateCombinedChainFailureHeight) - } - } else if let attempts = self.mockValidateCombinedChainFailAfterAttempts { - self.mockValidateCombinedChainFailAfterAttempts = attempts - 1 - if attempts > 0 { - return await validationResult(fsBlockDbRoot: fsBlockDbRoot, dbData: dbData, networkType: networkType) - } else { - if attempts == 0 { - return Int32(mockValidateCombinedChainFailureHeight) - } else if attempts < 0 && mockValidateCombinedChainKeepFailing { - return Int32(mockValidateCombinedChainFailureHeight) + await rustBackendMock.setGetVerifiedBalanceAccountClosure() { account in + return try await rustBackend.getVerifiedBalance(account: account) + } + + await rustBackendMock.setValidateCombinedChainLimitClosure() { [weak self] limit in + guard let self else { throw RustWeldingError.genericError(message: "Self is nil") } + if let rate = mockValidateCombinedChainSuccessRate { + if Self.shouldSucceed(successRate: rate) { + return try await rustBackend.validateCombinedChain(limit: limit) } else { - return await validationResult(fsBlockDbRoot: fsBlockDbRoot, dbData: dbData, networkType: networkType) + throw mockValidateCombinedChainFailureError + } + } else if let attempts = self.mockValidateCombinedChainFailAfterAttempts { + self.mockValidateCombinedChainFailAfterAttempts = attempts - 1 + if attempts > 0 { + return try await rustBackend.validateCombinedChain(limit: limit) + } else { + if attempts == 0 { + throw mockValidateCombinedChainFailureError + } else if attempts < 0 && mockValidateCombinedChainKeepFailing { + throw mockValidateCombinedChainFailureError + } else { + return try await rustBackend.validateCombinedChain(limit: limit) + } } - } - } - return await rustBackend.validateCombinedChain(fsBlockDbRoot: fsBlockDbRoot, dbData: dbData, networkType: networkType) - } - - private static func validationResult(fsBlockDbRoot: URL, dbData: URL, networkType: NetworkType) async -> Int32 { - if mockDataDb { - return -1 - } else { - return await rustBackend.validateCombinedChain(fsBlockDbRoot: fsBlockDbRoot, dbData: dbData, networkType: networkType) - } - } - - static func rewindToHeight(dbData: URL, height: Int32, networkType: NetworkType) async -> Bool { - mockDataDb ? true : rustBackend.rewindToHeight(dbData: dbData, height: height, networkType: networkType) - } - - static func scanBlocks(fsBlockDbRoot: URL, dbData: URL, limit: UInt32, networkType: NetworkType) async -> Bool { - if let rate = mockScanblocksSuccessRate { - if shouldSucceed(successRate: rate) { - return mockDataDb ? true : await rustBackend.scanBlocks(fsBlockDbRoot: fsBlockDbRoot, dbData: dbData, networkType: networkType) } else { - return false + return try await rustBackend.validateCombinedChain(limit: limit) } } - return await rustBackend.scanBlocks(fsBlockDbRoot: fsBlockDbRoot, dbData: dbData, networkType: Self.networkType) + + await rustBackendMock.setRewindToHeightHeightClosure() { height in + try await rustBackend.rewindToHeight(height: height) + } + + await rustBackendMock.setRewindCacheToHeightHeightClosure() { _ in } + + await rustBackendMock.setScanBlocksLimitClosure() { limit in + try await rustBackend.scanBlocks(limit: limit) + } } - - static func createToAddress( - dbData: URL, - account: Int32, - extsk: String, - consensusBranchId: Int32, - to address: String, - value: Int64, - memo: String?, - spendParamsPath: String, - outputParamsPath: String, - networkType: NetworkType - ) async -> Int64 { - -1 - } - - static func shouldSucceed(successRate: Float) -> Bool { + + private static func shouldSucceed(successRate: Float) -> Bool { let random = Float.random(in: 0.0...1.0) return random <= successRate } - - static func deriveExtendedFullViewingKey(_ spendingKey: String, networkType: NetworkType) throws -> String? { - nil - } - - static func deriveExtendedFullViewingKeys(seed: String, accounts: Int32, networkType: NetworkType) throws -> [String]? { - nil - } - - static func deriveExtendedSpendingKeys(seed: String, accounts: Int32, networkType: NetworkType) throws -> [String]? { - nil - } - - static func decryptAndStoreTransaction(dbData: URL, txBytes: [UInt8], minedHeight: Int32, networkType: NetworkType) async -> Bool { - false - } } extension SaplingParamsSourceURL { diff --git a/Tests/TestUtils/SynchronizerMock.swift b/Tests/TestUtils/SynchronizerMock.swift deleted file mode 100644 index 2f168535..00000000 --- a/Tests/TestUtils/SynchronizerMock.swift +++ /dev/null @@ -1,174 +0,0 @@ -// -// SynchronizerMock.swift -// -// -// Created by Michal Fousek on 20.03.2023. -// - -import Combine -import Foundation -@testable import ZcashLightClientKit - -class SynchronizerMock: Synchronizer { - init() { } - - var underlyingAlias: ZcashSynchronizerAlias! = nil - var alias: ZcashLightClientKit.ZcashSynchronizerAlias { underlyingAlias } - - var underlyingStateStream: AnyPublisher! = nil - var stateStream: AnyPublisher { underlyingStateStream } - - var underlyingLatestState: SynchronizerState! = nil - var latestState: SynchronizerState { underlyingLatestState } - - var underlyingEventStream: AnyPublisher! = nil - var eventStream: AnyPublisher { underlyingEventStream } - - var underlyingConnectionState: ConnectionState! = nil - var connectionState: ConnectionState { underlyingConnectionState } - - let metrics = SDKMetrics() - - var prepareWithSeedViewingKeysWalletBirthdayClosure: ( - ([UInt8]?, [UnifiedFullViewingKey], BlockHeight) async throws -> Initializer.InitializationResult - )! = nil - func prepare( - with seed: [UInt8]?, - viewingKeys: [UnifiedFullViewingKey], - walletBirthday: BlockHeight - ) async throws -> Initializer.InitializationResult { - return try await prepareWithSeedViewingKeysWalletBirthdayClosure(seed, viewingKeys, walletBirthday) - } - - var startRetryClosure: ((Bool) async throws -> Void)! = nil - func start(retry: Bool) async throws { - try await startRetryClosure(retry) - } - - var stopClosure: (() async -> Void)! = nil - func stop() async { - await stopClosure() - } - - var getSaplingAddressAccountIndexClosure: ((Int) async throws -> SaplingAddress)! = nil - func getSaplingAddress(accountIndex: Int) async throws -> SaplingAddress { - return try await getSaplingAddressAccountIndexClosure(accountIndex) - } - - var getUnifiedAddressAccountIndexClosure: ((Int) async throws -> UnifiedAddress)! = nil - func getUnifiedAddress(accountIndex: Int) async throws -> UnifiedAddress { - return try await getUnifiedAddressAccountIndexClosure(accountIndex) - } - - var getTransparentAddressAccountIndexClosure: ((Int) async throws -> TransparentAddress)! = nil - func getTransparentAddress(accountIndex: Int) async throws -> TransparentAddress { - return try await getTransparentAddressAccountIndexClosure(accountIndex) - } - - var sendToAddressSpendingKeyZatoshiToAddressMemoClosure: ( - (UnifiedSpendingKey, Zatoshi, Recipient, Memo?) async throws -> PendingTransactionEntity - )! = nil - func sendToAddress(spendingKey: UnifiedSpendingKey, zatoshi: Zatoshi, toAddress: Recipient, memo: Memo?) async throws -> PendingTransactionEntity { - return try await sendToAddressSpendingKeyZatoshiToAddressMemoClosure(spendingKey, zatoshi, toAddress, memo) - } - - var shieldFundsSpendingKeyMemoShieldingThresholdClosure: ((UnifiedSpendingKey, Memo, Zatoshi) async throws -> PendingTransactionEntity)! = nil - func shieldFunds(spendingKey: UnifiedSpendingKey, memo: Memo, shieldingThreshold: Zatoshi) async throws -> PendingTransactionEntity { - return try await shieldFundsSpendingKeyMemoShieldingThresholdClosure(spendingKey, memo, shieldingThreshold) - } - - var cancelSpendTransactionClosure: ((PendingTransactionEntity) async -> Bool)! = nil - func cancelSpend(transaction: PendingTransactionEntity) async -> Bool { - return await cancelSpendTransactionClosure(transaction) - } - - var underlyingPendingTransactions: [PendingTransactionEntity]! = nil - var pendingTransactions: [PendingTransactionEntity] { - get async { underlyingPendingTransactions } - } - - var underlyingClearedTransactions: [ZcashTransaction.Overview]! = nil - var clearedTransactions: [ZcashTransaction.Overview] { - get async { underlyingClearedTransactions } - } - - var underlyingSentTransactions: [ZcashTransaction.Sent]! = nil - var sentTransactions: [ZcashTransaction.Sent] { - get async { underlyingSentTransactions } - } - - var underlyingReceivedTransactions: [ZcashTransaction.Received]! = nil - var receivedTransactions: [ZcashTransaction.Received] { - get async { underlyingReceivedTransactions } - } - - var paginatedTransactionsOfKindClosure: ((TransactionKind) -> PaginatedTransactionRepository)! = nil - func paginatedTransactions(of kind: TransactionKind) -> PaginatedTransactionRepository { - return paginatedTransactionsOfKindClosure(kind) - } - - var getMemosForTransactionClosure: ((ZcashTransaction.Overview) async throws -> [Memo])! = nil - func getMemos(for transaction: ZcashTransaction.Overview) async throws -> [Memo] { - return try await getMemosForTransactionClosure(transaction) - } - - var getMemosForReceivedTransactionClosure: ((ZcashTransaction.Received) async throws -> [Memo])! = nil - func getMemos(for receivedTransaction: ZcashTransaction.Received) async throws -> [Memo] { - return try await getMemosForReceivedTransactionClosure(receivedTransaction) - } - - var getMemosForSentTransactionClosure: ((ZcashTransaction.Sent) async throws -> [Memo])! = nil - func getMemos(for sentTransaction: ZcashTransaction.Sent) async throws -> [Memo] { - return try await getMemosForSentTransactionClosure(sentTransaction) - } - - var getRecipientsForClearedTransactionClosure: ((ZcashTransaction.Overview) async -> [TransactionRecipient])! = nil - func getRecipients(for transaction: ZcashTransaction.Overview) async -> [TransactionRecipient] { - return await getRecipientsForClearedTransactionClosure(transaction) - } - - var getRecipientsForSentTransactionClosure: ((ZcashTransaction.Sent) async -> [TransactionRecipient])! = nil - func getRecipients(for transaction: ZcashTransaction.Sent) async -> [TransactionRecipient] { - return await getRecipientsForSentTransactionClosure(transaction) - } - - var allConfirmedTransactionsFromTransactionClosure: ((ZcashTransaction.Overview, Int) async throws -> [ZcashTransaction.Overview])! = nil - func allConfirmedTransactions(from transaction: ZcashTransaction.Overview, limit: Int) async throws -> [ZcashTransaction.Overview] { - return try await allConfirmedTransactionsFromTransactionClosure(transaction, limit) - } - - var latestHeightClosure: (() async throws -> BlockHeight)! = nil - func latestHeight() async throws -> BlockHeight { - return try await latestHeightClosure() - } - - var refreshUTXOsAddressFromHeightClosure: ((TransparentAddress, BlockHeight) async throws -> RefreshedUTXOs)! = nil - func refreshUTXOs(address: TransparentAddress, from height: BlockHeight) async throws -> RefreshedUTXOs { - return try await refreshUTXOsAddressFromHeightClosure(address, height) - } - - var getTransparentBalanceAccountIndexClosure: ((Int) async throws -> WalletBalance)! = nil - func getTransparentBalance(accountIndex: Int) async throws -> WalletBalance { - return try await getTransparentBalanceAccountIndexClosure(accountIndex) - } - - var getShieldedBalanceAccountIndexClosure: ((Int) async throws -> Zatoshi)! = nil - func getShieldedBalance(accountIndex: Int) async throws -> Zatoshi { - try await getShieldedBalanceAccountIndexClosure(accountIndex) - } - - var getShieldedVerifiedBalanceAccountIndexClosure: ((Int) async throws -> Zatoshi)! = nil - func getShieldedVerifiedBalance(accountIndex: Int) async throws -> Zatoshi { - try await getShieldedVerifiedBalanceAccountIndexClosure(accountIndex) - } - - var rewindPolicyClosure: ((RewindPolicy) -> AnyPublisher)! = nil - func rewind(_ policy: RewindPolicy) -> AnyPublisher { - return rewindPolicyClosure(policy) - } - - var wipeClosure: (() -> AnyPublisher)! = nil - func wipe() -> AnyPublisher { - return wipeClosure() - } -} diff --git a/Tests/TestUtils/TestCoordinator.swift b/Tests/TestUtils/TestCoordinator.swift index a37d1ed2..6c50ce06 100644 --- a/Tests/TestUtils/TestCoordinator.swift +++ b/Tests/TestUtils/TestCoordinator.swift @@ -52,56 +52,19 @@ class TestCoordinator { singleCallTimeoutInMillis: 10000, streamingCallTimeoutInMillis: 1000000 ) - - convenience init( + + init( alias: ZcashSynchronizerAlias = .default, walletBirthday: BlockHeight, network: ZcashNetwork, callPrepareInConstructor: Bool = true, endpoint: LightWalletEndpoint = TestCoordinator.defaultEndpoint, syncSessionIDGenerator: SyncSessionIDGenerator = UniqueSyncSessionIDGenerator() - ) async throws { - let derivationTool = DerivationTool(networkType: network.networkType) - - let spendingKey = try derivationTool.deriveUnifiedSpendingKey( - seed: Environment.seedBytes, - accountIndex: 0 - ) - - let ufvk = try derivationTool.deriveUnifiedFullViewingKey(from: spendingKey) - - try await self.init( - alias: alias, - spendingKey: spendingKey, - unifiedFullViewingKey: ufvk, - walletBirthday: walletBirthday, - network: network, - callPrepareInConstructor: callPrepareInConstructor, - endpoint: endpoint, - syncSessionIDGenerator: syncSessionIDGenerator - ) - } - - required init( - alias: ZcashSynchronizerAlias = .default, - spendingKey: UnifiedSpendingKey, - unifiedFullViewingKey: UnifiedFullViewingKey, - walletBirthday: BlockHeight, - network: ZcashNetwork, - callPrepareInConstructor: Bool = true, - endpoint: LightWalletEndpoint = TestCoordinator.defaultEndpoint, - syncSessionIDGenerator: SyncSessionIDGenerator ) async throws { await InternalSyncProgress(alias: alias, storage: UserDefaults.standard, logger: logger).rewind(to: 0) - self.spendingKey = spendingKey - self.viewingKey = unifiedFullViewingKey - self.birthday = walletBirthday - self.databases = TemporaryDbBuilder.build() - self.network = network - - let liveService = LightWalletServiceFactory(endpoint: endpoint).make() - self.service = DarksideWalletService(endpoint: endpoint, service: liveService) + let databases = TemporaryDbBuilder.build() + self.databases = databases let initializer = Initializer( cacheDbURL: nil, @@ -117,9 +80,21 @@ class TestCoordinator { logLevel: .debug ) - let synchronizer = SDKSynchronizer(initializer: initializer, sessionGenerator: syncSessionIDGenerator, sessionTicker: .live) - - self.synchronizer = synchronizer + let derivationTool = initializer.makeDerivationTool() + + self.spendingKey = try await derivationTool.deriveUnifiedSpendingKey( + seed: Environment.seedBytes, + accountIndex: 0 + ) + + self.viewingKey = try await derivationTool.deriveUnifiedFullViewingKey(from: spendingKey) + self.birthday = walletBirthday + self.network = network + + let liveService = LightWalletServiceFactory(endpoint: endpoint).make() + self.service = DarksideWalletService(endpoint: endpoint, service: liveService) + + self.synchronizer = SDKSynchronizer(initializer: initializer, sessionGenerator: syncSessionIDGenerator, sessionTicker: .live) subscribeToState(synchronizer: self.synchronizer) if callPrepareInConstructor { diff --git a/Tests/TestUtils/TestDbBuilder.swift b/Tests/TestUtils/TestDbBuilder.swift index 587fc8e7..62d201bb 100644 --- a/Tests/TestUtils/TestDbBuilder.swift +++ b/Tests/TestUtils/TestDbBuilder.swift @@ -58,16 +58,12 @@ enum TestDbBuilder { Bundle.module.url(forResource: "darkside_caches", withExtension: "db") } - static func prepopulatedDataDbProvider() async throws -> ConnectionProvider? { + static func prepopulatedDataDbProvider(rustBackend: ZcashRustBackendWelding) async throws -> ConnectionProvider? { guard let url = prePopulatedMainnetDataDbURL() else { return nil } let provider = SimpleConnectionProvider(path: url.absoluteString, readonly: true) - let initResult = try await ZcashRustBackend.initDataDb( - dbData: url, - seed: Environment.seedBytes, - networkType: .mainnet - ) + let initResult = try await rustBackend.initDataDb(seed: Environment.seedBytes) switch initResult { case .success: return provider @@ -76,19 +72,19 @@ enum TestDbBuilder { } } - static func transactionRepository() async throws -> TransactionRepository? { - guard let provider = try await prepopulatedDataDbProvider() else { return nil } + static func transactionRepository(rustBackend: ZcashRustBackendWelding) async throws -> TransactionRepository? { + guard let provider = try await prepopulatedDataDbProvider(rustBackend: rustBackend) else { return nil } return TransactionSQLDAO(dbProvider: provider) } - static func sentNotesRepository() async throws -> SentNotesRepository? { - guard let provider = try await prepopulatedDataDbProvider() else { return nil } + static func sentNotesRepository(rustBackend: ZcashRustBackendWelding) async throws -> SentNotesRepository? { + guard let provider = try await prepopulatedDataDbProvider(rustBackend: rustBackend) else { return nil } return SentNotesSQLDAO(dbProvider: provider) } - static func receivedNotesRepository() async throws -> ReceivedNoteRepository? { - guard let provider = try await prepopulatedDataDbProvider() else { return nil } + static func receivedNotesRepository(rustBackend: ZcashRustBackendWelding) async throws -> ReceivedNoteRepository? { + guard let provider = try await prepopulatedDataDbProvider(rustBackend: rustBackend) else { return nil } return ReceivedNotesSQLDAO(dbProvider: provider) } diff --git a/Tests/TestUtils/Tests+Utils.swift b/Tests/TestUtils/Tests+Utils.swift index 0312b13b..961f689f 100644 --- a/Tests/TestUtils/Tests+Utils.swift +++ b/Tests/TestUtils/Tests+Utils.swift @@ -9,10 +9,10 @@ import Combine import Foundation import GRPC -import ZcashLightClientKit import XCTest import NIO import NIOTransportServices +@testable import ZcashLightClientKit enum Environment { static let lightwalletdKey = "LIGHTWALLETD_ADDRESS" @@ -28,6 +28,11 @@ enum Environment { } static let testRecipientAddress = "zs17mg40levjezevuhdp5pqrd52zere7r7vrjgdwn5sj4xsqtm20euwahv9anxmwr3y3kmwuz8k55a" + + static var uniqueTestTempDirectory: URL { + URL(fileURLWithPath: NSString(string: NSTemporaryDirectory()) + .appendingPathComponent("tmp-\(Int.random(in: 0 ... .max))")) + } } public enum Constants { @@ -128,3 +133,21 @@ func parametersReady() -> Bool { return true } + +extension ZcashRustBackend { + static func makeForTests( + dbData: URL = try! __dataDbURL(), + fsBlockDbRoot: URL, + spendParamsPath: URL = SaplingParamsSourceURL.default.spendParamFileURL, + outputParamsPath: URL = SaplingParamsSourceURL.default.outputParamFileURL, + networkType: NetworkType + ) -> ZcashRustBackendWelding { + ZcashRustBackend( + dbData: dbData, + fsBlockDbRoot: fsBlockDbRoot, + spendParamsPath: spendParamsPath, + outputParamsPath: outputParamsPath, + networkType: networkType + ) + } +} diff --git a/Tests/TestUtils/AlternativeSynchronizerAPITestsData.swift b/Tests/TestUtils/TestsData.swift similarity index 64% rename from Tests/TestUtils/AlternativeSynchronizerAPITestsData.swift rename to Tests/TestUtils/TestsData.swift index 9838a881..74ba8f16 100644 --- a/Tests/TestUtils/AlternativeSynchronizerAPITestsData.swift +++ b/Tests/TestUtils/TestsData.swift @@ -1,5 +1,5 @@ // -// AlternativeSynchronizerAPITestsData.swift +// TestsData.swift // // // Created by Michal Fousek on 20.03.2023. @@ -8,13 +8,29 @@ import Foundation @testable import ZcashLightClientKit -class AlternativeSynchronizerAPITestsData { - let derivationTools = DerivationTool(networkType: .testnet) +class TestsData { + let networkType: NetworkType + + lazy var initialier = { + Initializer( + cacheDbURL: nil, + fsBlockDbRoot: URL(fileURLWithPath: "/"), + dataDbURL: URL(fileURLWithPath: "/"), + pendingDbURL: URL(fileURLWithPath: "/"), + endpoint: LightWalletEndpointBuilder.default, + network: ZcashNetworkBuilder.network(for: networkType), + spendParamsURL: URL(fileURLWithPath: "/"), + outputParamsURL: URL(fileURLWithPath: "/"), + saplingParamsSourceURL: .default + ) + }() + lazy var derivationTools: DerivationTool = { initialier.makeDerivationTool() }() let saplingAddress = SaplingAddress(validatedEncoding: "ztestsapling1ctuamfer5xjnnrdr3xdazenljx0mu0gutcf9u9e74tr2d3jwjnt0qllzxaplu54hgc2tyjdc2p6") let unifiedAddress = UnifiedAddress( validatedEncoding: """ u1l9f0l4348negsncgr9pxd9d3qaxagmqv3lnexcplmufpq7muffvfaue6ksevfvd7wrz7xrvn95rc5zjtn7ugkmgh5rnxswmcj30y0pw52pn0zjvy38rn2esfgve64rj5pcmazxgpyuj - """ + """, + networkType: .testnet ) let transparentAddress = TransparentAddress(validatedEncoding: "t1dRJRY7GmyeykJnMH38mdQoaZtFhn1QmGz") lazy var pendingTransactionEntity = { @@ -73,9 +89,19 @@ class AlternativeSynchronizerAPITestsData { }() var seed: [UInt8] = Environment.seedBytes - lazy var spendingKey: UnifiedSpendingKey = { try! derivationTools.deriveUnifiedSpendingKey(seed: seed, accountIndex: 0) }() - lazy var viewingKey: UnifiedFullViewingKey = { try! derivationTools.deriveUnifiedFullViewingKey(from: spendingKey) }() + var spendingKey: UnifiedSpendingKey { + get async { + try! await derivationTools.deriveUnifiedSpendingKey(seed: seed, accountIndex: 0) + } + } + var viewingKey: UnifiedFullViewingKey { + get async { + try! await derivationTools.deriveUnifiedFullViewingKey(from: spendingKey) + } + } var birthday: BlockHeight = 123000 - init() { } + init(networkType: NetworkType) { + self.networkType = networkType + } }