diff --git a/MIGRATIONS.md b/MIGRATIONS.md index be382c77..50a0ae2f 100644 --- a/MIGRATIONS.md +++ b/MIGRATIONS.md @@ -5,6 +5,8 @@ Upcoming Migrating to Version 1.4.* from 1.3.* -------------------------------------- Various APIs that have always been considered private have been moved into a new package called `internal`. While this should not be a breaking change, clients that might have relied on these internal classes should stop doing so. +A number of methods have been converted to suspending functions, because they were performing slow or blocking calls (e.g. disk IO) internally. This is a breaking change. + Migrating to Version 1.3.* from 1.2.* -------------------------------------- The biggest breaking changes in 1.3 that inspired incrementing the minor version number was simplifying down to one "network aware" library rather than two separate libraries, each dedicated to either testnet or mainnet. This greatly simplifies the gradle configuration and has lots of other benefits. Wallets can now set a network with code similar to the following: diff --git a/darkside-test-lib/src/androidTest/java/cash/z/ecc/android/sdk/darkside/test/DarksideTestCoordinator.kt b/darkside-test-lib/src/androidTest/java/cash/z/ecc/android/sdk/darkside/test/DarksideTestCoordinator.kt index 1bd8a607..c54ab58b 100644 --- a/darkside-test-lib/src/androidTest/java/cash/z/ecc/android/sdk/darkside/test/DarksideTestCoordinator.kt +++ b/darkside-test-lib/src/androidTest/java/cash/z/ecc/android/sdk/darkside/test/DarksideTestCoordinator.kt @@ -141,8 +141,10 @@ class DarksideTestCoordinator(val wallet: TestWallet) { inner class DarksideTestValidator { fun validateHasBlock(height: Int) { - assertTrue((synchronizer as SdkSynchronizer).findBlockHashAsHex(height) != null) - assertTrue((synchronizer as SdkSynchronizer).findBlockHash(height)?.size ?: 0 > 0) + runBlocking { + assertTrue((synchronizer as SdkSynchronizer).findBlockHashAsHex(height) != null) + assertTrue((synchronizer as SdkSynchronizer).findBlockHash(height)?.size ?: 0 > 0) + } } fun validateLatestHeight(height: Int) = runBlocking { @@ -185,7 +187,7 @@ class DarksideTestCoordinator(val wallet: TestWallet) { } fun validateBlockHash(height: Int, expectedHash: String) { - val hash = (synchronizer as SdkSynchronizer).findBlockHashAsHex(height) + val hash = runBlocking { (synchronizer as SdkSynchronizer).findBlockHashAsHex(height) } assertEquals(expectedHash, hash) } @@ -194,7 +196,7 @@ class DarksideTestCoordinator(val wallet: TestWallet) { } fun validateTxCount(count: Int) { - val txCount = (synchronizer as SdkSynchronizer).getTransactionCount() + val txCount = runBlocking { (synchronizer as SdkSynchronizer).getTransactionCount() } assertEquals("Expected $count transactions but found $txCount instead!", count, txCount) } diff --git a/darkside-test-lib/src/androidTest/java/cash/z/ecc/android/sdk/darkside/test/TestWallet.kt b/darkside-test-lib/src/androidTest/java/cash/z/ecc/android/sdk/darkside/test/TestWallet.kt index fc9e37aa..e63ca920 100644 --- a/darkside-test-lib/src/androidTest/java/cash/z/ecc/android/sdk/darkside/test/TestWallet.kt +++ b/darkside-test-lib/src/androidTest/java/cash/z/ecc/android/sdk/darkside/test/TestWallet.kt @@ -23,6 +23,7 @@ import kotlinx.coroutines.flow.onCompletion import kotlinx.coroutines.flow.takeWhile import kotlinx.coroutines.launch import kotlinx.coroutines.newFixedThreadPoolContext +import kotlinx.coroutines.runBlocking import java.util.concurrent.TimeoutException /** @@ -51,19 +52,29 @@ class TestWallet( val walletScope = CoroutineScope( SupervisorJob() + newFixedThreadPoolContext(3, this.javaClass.simpleName) ) + + // Although runBlocking isn't great, this usage is OK because this is only used within the + // automated tests + private val context = InstrumentationRegistry.getInstrumentation().context private val seed: ByteArray = Mnemonics.MnemonicCode(seedPhrase).toSeed() - private val shieldedSpendingKey = DerivationTool.deriveSpendingKeys(seed, network = network)[0] - private val transparentSecretKey = DerivationTool.deriveTransparentSecretKey(seed, network = network) - val initializer = Initializer(context) { config -> - config.importWallet(seed, startHeight, network, host, alias = alias) + private val shieldedSpendingKey = + runBlocking { DerivationTool.deriveSpendingKeys(seed, network = network)[0] } + private val transparentSecretKey = + runBlocking { DerivationTool.deriveTransparentSecretKey(seed, network = network) } + val initializer = runBlocking { + Initializer.new(context) { config -> + runBlocking { config.importWallet(seed, startHeight, network, host, alias = alias) } + } } val synchronizer: SdkSynchronizer = Synchronizer(initializer) as SdkSynchronizer val service = (synchronizer.processor.downloader.lightWalletService as LightWalletGrpcService) val available get() = synchronizer.saplingBalances.value.availableZatoshi - val shieldedAddress = DerivationTool.deriveShieldedAddress(seed, network = network) - val transparentAddress = DerivationTool.deriveTransparentAddress(seed, network = network) + val shieldedAddress = + runBlocking { DerivationTool.deriveShieldedAddress(seed, network = network) } + val transparentAddress = + runBlocking { DerivationTool.deriveTransparentAddress(seed, network = network) } val birthdayHeight get() = synchronizer.latestBirthdayHeight val networkName get() = synchronizer.network.networkName val connectionInfo get() = service.connectionInfo.toString() diff --git a/demo-app/src/androidTest/java/cash/z/wallet/sdk/sample/demoapp/SampleCodeTest.kt b/demo-app/src/androidTest/java/cash/z/wallet/sdk/sample/demoapp/SampleCodeTest.kt index e0987349..a28913dd 100644 --- a/demo-app/src/androidTest/java/cash/z/wallet/sdk/sample/demoapp/SampleCodeTest.kt +++ b/demo-app/src/androidTest/java/cash/z/wallet/sdk/sample/demoapp/SampleCodeTest.kt @@ -52,7 +52,12 @@ class SampleCodeTest { // /////////////////////////////////////////////////// // Derive Extended Spending Key @Test fun deriveSpendingKey() { - val spendingKeys = DerivationTool.deriveSpendingKeys(seed, ZcashNetwork.Mainnet) + val spendingKeys = runBlocking { + DerivationTool.deriveSpendingKeys( + seed, + ZcashNetwork.Mainnet + ) + } assertEquals(1, spendingKeys.size) log("Spending Key: ${spendingKeys?.get(0)}") } @@ -140,7 +145,7 @@ class SampleCodeTest { private val lightwalletdHost: String = ZcashNetwork.Mainnet.defaultHost private val context = InstrumentationRegistry.getInstrumentation().targetContext - private val synchronizer = Synchronizer(Initializer(context) {}) + private val synchronizer = Synchronizer(runBlocking { Initializer.new(context) {} }) @BeforeClass @JvmStatic diff --git a/demo-app/src/main/java/cash/z/ecc/android/sdk/demoapp/demos/getaddress/GetAddressFragment.kt b/demo-app/src/main/java/cash/z/ecc/android/sdk/demoapp/demos/getaddress/GetAddressFragment.kt index d90f0a69..6f45b4bc 100644 --- a/demo-app/src/main/java/cash/z/ecc/android/sdk/demoapp/demos/getaddress/GetAddressFragment.kt +++ b/demo-app/src/main/java/cash/z/ecc/android/sdk/demoapp/demos/getaddress/GetAddressFragment.kt @@ -2,6 +2,7 @@ package cash.z.ecc.android.sdk.demoapp.demos.getaddress import android.os.Bundle import android.view.LayoutInflater +import androidx.lifecycle.lifecycleScope import cash.z.ecc.android.bip39.Mnemonics import cash.z.ecc.android.bip39.toSeed import cash.z.ecc.android.sdk.demoapp.BaseDemoFragment @@ -11,6 +12,8 @@ import cash.z.ecc.android.sdk.demoapp.util.fromResources import cash.z.ecc.android.sdk.tool.DerivationTool import cash.z.ecc.android.sdk.type.UnifiedViewingKey import cash.z.ecc.android.sdk.type.ZcashNetwork +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking /** * Displays the address associated with the seed defined by the default config. To modify the seed @@ -34,14 +37,16 @@ class GetAddressFragment : BaseDemoFragment() { seed = Mnemonics.MnemonicCode(seedPhrase).toSeed() // the derivation tool can be used for generating keys and addresses - viewingKey = DerivationTool.deriveUnifiedViewingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext())).first() + viewingKey = runBlocking { DerivationTool.deriveUnifiedViewingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext())).first() } } private fun displayAddress() { // a full fledged app would just get the address from the synchronizer - val zaddress = DerivationTool.deriveShieldedAddress(seed, ZcashNetwork.fromResources(requireApplicationContext())) - val taddress = DerivationTool.deriveTransparentAddress(seed, ZcashNetwork.fromResources(requireApplicationContext())) - binding.textInfo.text = "z-addr:\n$zaddress\n\n\nt-addr:\n$taddress" + viewLifecycleOwner.lifecycleScope.launchWhenStarted { + val zaddress = DerivationTool.deriveShieldedAddress(seed, ZcashNetwork.fromResources(requireApplicationContext())) + val taddress = DerivationTool.deriveTransparentAddress(seed, ZcashNetwork.fromResources(requireApplicationContext())) + binding.textInfo.text = "z-addr:\n$zaddress\n\n\nt-addr:\n$taddress" + } } // TODO: show an example with the synchronizer @@ -65,10 +70,15 @@ class GetAddressFragment : BaseDemoFragment() { // override fun onActionButtonClicked() { - copyToClipboard( - DerivationTool.deriveShieldedAddress(viewingKey.extfvk, ZcashNetwork.fromResources(requireApplicationContext())), - "Shielded address copied to clipboard!" - ) + viewLifecycleOwner.lifecycleScope.launch { + copyToClipboard( + DerivationTool.deriveShieldedAddress( + viewingKey.extfvk, + ZcashNetwork.fromResources(requireApplicationContext()) + ), + "Shielded address copied to clipboard!" + ) + } } override fun inflateBinding(layoutInflater: LayoutInflater): FragmentGetAddressBinding = diff --git a/demo-app/src/main/java/cash/z/ecc/android/sdk/demoapp/demos/getbalance/GetBalanceFragment.kt b/demo-app/src/main/java/cash/z/ecc/android/sdk/demoapp/demos/getbalance/GetBalanceFragment.kt index 9c9251ee..658e3080 100644 --- a/demo-app/src/main/java/cash/z/ecc/android/sdk/demoapp/demos/getbalance/GetBalanceFragment.kt +++ b/demo-app/src/main/java/cash/z/ecc/android/sdk/demoapp/demos/getbalance/GetBalanceFragment.kt @@ -17,6 +17,7 @@ import cash.z.ecc.android.sdk.ext.convertZatoshiToZecString import cash.z.ecc.android.sdk.tool.DerivationTool import cash.z.ecc.android.sdk.type.WalletBalance import cash.z.ecc.android.sdk.type.ZcashNetwork +import kotlinx.coroutines.runBlocking /** * Displays the available balance && total balance associated with the seed defined by the default config. @@ -43,13 +44,13 @@ class GetBalanceFragment : BaseDemoFragment() { val seed = Mnemonics.MnemonicCode(seedPhrase).toSeed() // converting seed into viewingKey - val viewingKey = DerivationTool.deriveUnifiedViewingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext())).first() + val viewingKey = runBlocking { DerivationTool.deriveUnifiedViewingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext())).first() } // using the ViewingKey to initialize - Initializer(requireApplicationContext()) { + runBlocking {Initializer.new(requireApplicationContext(), null) { it.setNetwork(ZcashNetwork.fromResources(requireApplicationContext())) it.importWallet(viewingKey, network = ZcashNetwork.fromResources(requireApplicationContext())) - }.let { initializer -> + }}.let { initializer -> synchronizer = Synchronizer(initializer) } } diff --git a/demo-app/src/main/java/cash/z/ecc/android/sdk/demoapp/demos/getprivatekey/GetPrivateKeyFragment.kt b/demo-app/src/main/java/cash/z/ecc/android/sdk/demoapp/demos/getprivatekey/GetPrivateKeyFragment.kt index c5a4723f..ac863fe5 100644 --- a/demo-app/src/main/java/cash/z/ecc/android/sdk/demoapp/demos/getprivatekey/GetPrivateKeyFragment.kt +++ b/demo-app/src/main/java/cash/z/ecc/android/sdk/demoapp/demos/getprivatekey/GetPrivateKeyFragment.kt @@ -2,6 +2,7 @@ package cash.z.ecc.android.sdk.demoapp.demos.getprivatekey import android.os.Bundle import android.view.LayoutInflater +import androidx.lifecycle.lifecycleScope import cash.z.ecc.android.bip39.Mnemonics import cash.z.ecc.android.bip39.toSeed import cash.z.ecc.android.sdk.demoapp.BaseDemoFragment @@ -10,6 +11,7 @@ import cash.z.ecc.android.sdk.demoapp.ext.requireApplicationContext import cash.z.ecc.android.sdk.demoapp.util.fromResources import cash.z.ecc.android.sdk.tool.DerivationTool import cash.z.ecc.android.sdk.type.ZcashNetwork +import kotlinx.coroutines.launch /** * Displays the viewing key and spending key associated with the seed used during the demo. The @@ -37,13 +39,22 @@ class GetPrivateKeyFragment : BaseDemoFragment() { private fun displayKeys() { // derive the keys from the seed: // demonstrate deriving spending keys for five accounts but only take the first one - val spendingKey = DerivationTool.deriveSpendingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext()), 5).first() + lifecycleScope.launchWhenStarted { + val spendingKey = DerivationTool.deriveSpendingKeys( + seed, + ZcashNetwork.fromResources(requireApplicationContext()), + 5 + ).first() - // derive the key that allows you to view but not spend transactions - val viewingKey = DerivationTool.deriveViewingKey(spendingKey, ZcashNetwork.fromResources(requireApplicationContext())) + // derive the key that allows you to view but not spend transactions + val viewingKey = DerivationTool.deriveViewingKey( + spendingKey, + ZcashNetwork.fromResources(requireApplicationContext()) + ) - // display the keys in the UI - binding.textInfo.setText("Spending Key:\n$spendingKey\n\nViewing Key:\n$viewingKey") + // display the keys in the UI + binding.textInfo.setText("Spending Key:\n$spendingKey\n\nViewing Key:\n$viewingKey") + } } // @@ -65,10 +76,15 @@ class GetPrivateKeyFragment : BaseDemoFragment() { // override fun onActionButtonClicked() { - copyToClipboard( - DerivationTool.deriveUnifiedViewingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext())).first().extpub, - "ViewingKey copied to clipboard!" - ) + lifecycleScope.launch { + copyToClipboard( + DerivationTool.deriveUnifiedViewingKeys( + seed, + ZcashNetwork.fromResources(requireApplicationContext()) + ).first().extpub, + "ViewingKey copied to clipboard!" + ) + } } override fun inflateBinding(layoutInflater: LayoutInflater): FragmentGetPrivateKeyBinding = diff --git a/demo-app/src/main/java/cash/z/ecc/android/sdk/demoapp/demos/listtransactions/ListTransactionsFragment.kt b/demo-app/src/main/java/cash/z/ecc/android/sdk/demoapp/demos/listtransactions/ListTransactionsFragment.kt index 6bb370c5..6a735e0f 100644 --- a/demo-app/src/main/java/cash/z/ecc/android/sdk/demoapp/demos/listtransactions/ListTransactionsFragment.kt +++ b/demo-app/src/main/java/cash/z/ecc/android/sdk/demoapp/demos/listtransactions/ListTransactionsFragment.kt @@ -19,6 +19,7 @@ import cash.z.ecc.android.sdk.ext.collectWith import cash.z.ecc.android.sdk.internal.twig import cash.z.ecc.android.sdk.tool.DerivationTool import cash.z.ecc.android.sdk.type.ZcashNetwork +import kotlinx.coroutines.runBlocking /** * List all transactions related to the given seed, since the given birthday. This begins by @@ -47,11 +48,16 @@ class ListTransactionsFragment : BaseDemoFragment() { // Use a BIP-39 library to convert a seed phrase into a byte array. Most wallets already // have the seed stored seed = Mnemonics.MnemonicCode(sharedViewModel.seedPhrase.value).toSeed() - initializer = Initializer(requireApplicationContext()) { - it.importWallet(seed, network = ZcashNetwork.fromResources(requireApplicationContext())) + initializer = runBlocking {Initializer.new(requireApplicationContext()) { + runBlocking { it.importWallet(seed, network = ZcashNetwork.fromResources(requireApplicationContext())) } it.alias = "Demo_Utxos" - } + }} synchronizer = Synchronizer(initializer) } @@ -102,7 +103,7 @@ class ListUtxosFragment : BaseDemoFragment() { txids?.map { it.data.apply { try { - initializer.rustBackend.decryptAndStoreTransaction(toByteArray()) + runBlocking { initializer.rustBackend.decryptAndStoreTransaction(toByteArray()) } } catch (t: Throwable) { twig("failed to decrypt and store transaction due to: $t") } @@ -154,7 +155,9 @@ class ListUtxosFragment : BaseDemoFragment() { super.onResume() resetInBackground() val seed = Mnemonics.MnemonicCode(sharedViewModel.seedPhrase.value).toSeed() - binding.inputAddress.setText(DerivationTool.deriveTransparentAddress(seed, ZcashNetwork.fromResources(requireApplicationContext()))) + viewLifecycleOwner.lifecycleScope.launchWhenStarted { + binding.inputAddress.setText(DerivationTool.deriveTransparentAddress(seed, ZcashNetwork.fromResources(requireApplicationContext()))) + } } var initialCount: Int = 0 diff --git a/demo-app/src/main/java/cash/z/ecc/android/sdk/demoapp/demos/send/SendFragment.kt b/demo-app/src/main/java/cash/z/ecc/android/sdk/demoapp/demos/send/SendFragment.kt index 6dd5803c..2ea445ab 100644 --- a/demo-app/src/main/java/cash/z/ecc/android/sdk/demoapp/demos/send/SendFragment.kt +++ b/demo-app/src/main/java/cash/z/ecc/android/sdk/demoapp/demos/send/SendFragment.kt @@ -32,6 +32,7 @@ import cash.z.ecc.android.sdk.internal.twig import cash.z.ecc.android.sdk.tool.DerivationTool import cash.z.ecc.android.sdk.type.WalletBalance import cash.z.ecc.android.sdk.type.ZcashNetwork +import kotlinx.coroutines.runBlocking /** * Demonstrates sending funds to an address. This is the most complex example that puts all of the @@ -63,13 +64,13 @@ class SendFragment : BaseDemoFragment() { // have the seed stored val seed = Mnemonics.MnemonicCode(seedPhrase).toSeed() - Initializer(requireApplicationContext()) { - it.importWallet(seed, network = ZcashNetwork.fromResources(requireApplicationContext())) + runBlocking { Initializer.new(requireApplicationContext()) { + runBlocking { it.importWallet(seed, network = ZcashNetwork.fromResources(requireApplicationContext())) } it.setNetwork(ZcashNetwork.fromResources(requireApplicationContext())) - }.let { initializer -> + }}.let { initializer -> synchronizer = Synchronizer(initializer) } - spendingKey = DerivationTool.deriveSpendingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext())).first() + spendingKey = runBlocking { DerivationTool.deriveSpendingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext())).first() } } // diff --git a/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/AssetTest.kt b/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/AssetTest.kt index 43b23fd4..03d17ed3 100644 --- a/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/AssetTest.kt +++ b/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/AssetTest.kt @@ -5,6 +5,7 @@ import androidx.test.core.app.ApplicationProvider import androidx.test.filters.SmallTest import cash.z.ecc.android.sdk.tool.WalletBirthdayTool import cash.z.ecc.android.sdk.type.ZcashNetwork +import kotlinx.coroutines.runBlocking import org.json.JSONObject import org.junit.Assert.assertEquals import org.junit.Assert.assertFalse @@ -92,9 +93,9 @@ class AssetTest { private data class JsonFile(val jsonObject: JSONObject, val filename: String) companion object { - fun listAssets(network: ZcashNetwork) = WalletBirthdayTool.listBirthdayDirectoryContents( + fun listAssets(network: ZcashNetwork) = runBlocking { WalletBirthdayTool.listBirthdayDirectoryContents( ApplicationProvider.getApplicationContext(), - WalletBirthdayTool.birthdayDirectory(network) - ) + WalletBirthdayTool.birthdayDirectory(network)) + } } } diff --git a/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/ext/TestExtensions.kt b/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/ext/TestExtensions.kt index b3db1eb3..44c52613 100644 --- a/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/ext/TestExtensions.kt +++ b/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/ext/TestExtensions.kt @@ -3,13 +3,14 @@ package cash.z.ecc.android.sdk.ext import cash.z.ecc.android.sdk.Initializer import cash.z.ecc.android.sdk.type.ZcashNetwork import cash.z.ecc.android.sdk.util.SimpleMnemonics +import kotlinx.coroutines.runBlocking import okhttp3.OkHttpClient import okhttp3.Request import org.json.JSONObject import ru.gildor.coroutines.okhttp.await fun Initializer.Config.seedPhrase(seedPhrase: String, network: ZcashNetwork) { - setSeed(SimpleMnemonics().toSeed(seedPhrase.toCharArray()), network) + runBlocking { setSeed(SimpleMnemonics().toSeed(seedPhrase.toCharArray()), network) } } object BlockExplorer { diff --git a/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/integration/TestnetIntegrationTest.kt b/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/integration/TestnetIntegrationTest.kt index 3b269d20..0d7f65c6 100644 --- a/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/integration/TestnetIntegrationTest.kt +++ b/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/integration/TestnetIntegrationTest.kt @@ -46,7 +46,13 @@ class TestnetIntegrationTest : ScopedTest() { @Test fun testLoadBirthday() { - val (height, hash, time, tree) = WalletBirthdayTool.loadNearest(context, synchronizer.network, saplingActivation + 1) + val (height, hash, time, tree) = runBlocking { + WalletBirthdayTool.loadNearest( + context, + synchronizer.network, + saplingActivation + 1 + ) + } assertEquals(saplingActivation, height) } @@ -118,9 +124,11 @@ class TestnetIntegrationTest : ScopedTest() { val toAddress = "zs1vp7kvlqr4n9gpehztr76lcn6skkss9p8keqs3nv8avkdtjrcctrvmk9a7u494kluv756jeee5k0" private val context = InstrumentationRegistry.getInstrumentation().context - private val initializer = Initializer(context) { config -> - config.setNetwork(ZcashNetwork.Testnet, host) - config.importWallet(seed, birthdayHeight, ZcashNetwork.Testnet) + private val initializer = runBlocking { + Initializer.new(context) { config -> + config.setNetwork(ZcashNetwork.Testnet, host) + runBlocking { config.importWallet(seed, birthdayHeight, ZcashNetwork.Testnet) } + } } private lateinit var synchronizer: Synchronizer diff --git a/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/jni/BranchIdTest.kt b/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/jni/BranchIdTest.kt index 31355603..eb37fe50 100644 --- a/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/jni/BranchIdTest.kt +++ b/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/jni/BranchIdTest.kt @@ -3,6 +3,7 @@ package cash.z.ecc.android.sdk.jni import cash.z.ecc.android.sdk.annotation.MaintainedTest import cash.z.ecc.android.sdk.annotation.TestPurpose import cash.z.ecc.android.sdk.type.ZcashNetwork +import kotlinx.coroutines.runBlocking import org.junit.Assert.assertEquals import org.junit.Test import org.junit.runner.RunWith @@ -43,8 +44,8 @@ class BranchIdTest( // is an abnormal use of the SDK because this really should run at the rust level // However, due to quirks on certain devices, we created this test at the Android level, // as a sanity check - val testnetBackend = RustBackend.init("", "", "", ZcashNetwork.Testnet) - val mainnetBackend = RustBackend.init("", "", "", ZcashNetwork.Mainnet) + val testnetBackend = runBlocking { RustBackend.init("", "", "", ZcashNetwork.Testnet) } + val mainnetBackend = runBlocking { RustBackend.init("", "", "", ZcashNetwork.Mainnet) } return listOf( // Mainnet Cases arrayOf("Sapling", 419_200, 1991772603L, "76b809bb", mainnetBackend), diff --git a/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/jni/TransparentTest.kt b/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/jni/TransparentTest.kt index ab36e288..559de7d1 100644 --- a/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/jni/TransparentTest.kt +++ b/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/jni/TransparentTest.kt @@ -9,6 +9,7 @@ import cash.z.ecc.android.sdk.internal.TroubleshootingTwig import cash.z.ecc.android.sdk.internal.Twig import cash.z.ecc.android.sdk.tool.DerivationTool import cash.z.ecc.android.sdk.type.ZcashNetwork +import kotlinx.coroutines.runBlocking import org.junit.Assert.assertEquals import org.junit.BeforeClass import org.junit.Test @@ -21,23 +22,23 @@ import org.junit.runners.Parameterized class TransparentTest(val expected: Expected, val network: ZcashNetwork) { @Test - fun deriveTransparentSecretKeyTest() { + fun deriveTransparentSecretKeyTest() = runBlocking { assertEquals(expected.tskCompressed, DerivationTool.deriveTransparentSecretKey(SEED, network = network)) } @Test - fun deriveTransparentAddressTest() { + fun deriveTransparentAddressTest() = runBlocking { assertEquals(expected.tAddr, DerivationTool.deriveTransparentAddress(SEED, network = network)) } @Test - fun deriveTransparentAddressFromSecretKeyTest() { + fun deriveTransparentAddressFromSecretKeyTest() = runBlocking { val pk = DerivationTool.deriveTransparentSecretKey(SEED, network = network) assertEquals(expected.tAddr, DerivationTool.deriveTransparentAddressFromPrivateKey(pk, network = network)) } @Test - fun deriveUnifiedViewingKeysFromSeedTest() { + fun deriveUnifiedViewingKeysFromSeedTest() = runBlocking { val uvks = DerivationTool.deriveUnifiedViewingKeys(SEED, network = network) assertEquals(1, uvks.size) val uvk = uvks.first() diff --git a/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/tool/WalletBirthdayToolTest.kt b/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/tool/WalletBirthdayToolTest.kt index c6eecc37..c6a9b07f 100644 --- a/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/tool/WalletBirthdayToolTest.kt +++ b/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/tool/WalletBirthdayToolTest.kt @@ -4,7 +4,8 @@ import android.content.Context import androidx.test.core.app.ApplicationProvider import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.filters.SmallTest -import cash.z.ecc.android.sdk.tool.WalletBirthdayTool.Companion.IS_FALLBACK_ON_FAILURE +import cash.z.ecc.android.sdk.tool.WalletBirthdayTool.IS_FALLBACK_ON_FAILURE +import kotlinx.coroutines.runBlocking import org.junit.Assert.assertEquals import org.junit.Test import org.junit.runner.RunWith @@ -25,11 +26,13 @@ class WalletBirthdayToolTest { val directory = "saplingtree/goodnet" val context = ApplicationProvider.getApplicationContext() - val birthday = WalletBirthdayTool.getFirstValidWalletBirthday( - context, - directory, - listOf("1300000.json", "1290000.json") - ) + val birthday = runBlocking { + WalletBirthdayTool.getFirstValidWalletBirthday( + context, + directory, + listOf("1300000.json", "1290000.json") + ) + } assertEquals(1300000, birthday.height) } @@ -42,11 +45,13 @@ class WalletBirthdayToolTest { val directory = "saplingtree/badnet" val context = ApplicationProvider.getApplicationContext() - val birthday = WalletBirthdayTool.getFirstValidWalletBirthday( - context, - directory, - listOf("1300000.json", "1290000.json") - ) + val birthday = runBlocking { + WalletBirthdayTool.getFirstValidWalletBirthday( + context, + directory, + listOf("1300000.json", "1290000.json") + ) + } assertEquals(1290000, birthday.height) } } diff --git a/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/util/BalancePrinterUtil.kt b/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/util/BalancePrinterUtil.kt index 84847522..2499e8c1 100644 --- a/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/util/BalancePrinterUtil.kt +++ b/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/util/BalancePrinterUtil.kt @@ -5,6 +5,7 @@ import cash.z.ecc.android.sdk.Initializer import cash.z.ecc.android.sdk.Synchronizer import cash.z.ecc.android.sdk.internal.TroubleshootingTwig import cash.z.ecc.android.sdk.internal.Twig +import cash.z.ecc.android.sdk.internal.ext.deleteSuspend import cash.z.ecc.android.sdk.internal.twig import cash.z.ecc.android.sdk.tool.WalletBirthdayTool import cash.z.ecc.android.sdk.type.WalletBirthday @@ -52,7 +53,7 @@ class BalancePrinterUtil { fun setup() { Twig.plant(TroubleshootingTwig()) cacheBlocks() - birthday = WalletBirthdayTool.loadNearest(context, network, birthdayHeight) + birthday = runBlocking { WalletBirthdayTool.loadNearest(context, network, birthdayHeight) } } private fun cacheBlocks() = runBlocking { @@ -66,8 +67,8 @@ class BalancePrinterUtil { // assertEquals(-1, error) } - private fun deleteDb(dbName: String) { - context.getDatabasePath(dbName).absoluteFile.delete() + private suspend fun deleteDb(dbName: String) { + context.getDatabasePath(dbName).absoluteFile.deleteSuspend() } @Test @@ -79,8 +80,8 @@ class BalancePrinterUtil { mnemonics.toSeed(seedPhrase.toCharArray()) }.collect { seed -> // TODO: clear the dataDb but leave the cacheDb - val initializer = Initializer(context) { config -> - config.importWallet(seed, birthdayHeight, network) + val initializer = Initializer.new(context) { config -> + runBlocking { config.importWallet(seed, birthdayHeight, network) } config.setNetwork(network) config.alias = alias } diff --git a/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/util/DataDbScannerUtil.kt b/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/util/DataDbScannerUtil.kt index f978947c..248b92a8 100644 --- a/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/util/DataDbScannerUtil.kt +++ b/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/util/DataDbScannerUtil.kt @@ -64,7 +64,13 @@ class DataDbScannerUtil { @Test @Ignore("This test is broken") fun scanExistingDb() { - synchronizer = Synchronizer(Initializer(context) { it.setBirthdayHeight(birthdayHeight) }) + synchronizer = Synchronizer(runBlocking { + Initializer.new(context) { + it.setBirthdayHeight( + birthdayHeight + ) + } + }) println("sync!") synchronizer.start() diff --git a/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/util/TestWallet.kt b/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/util/TestWallet.kt index 45bf2279..03692708 100644 --- a/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/util/TestWallet.kt +++ b/sdk-lib/src/androidTest/java/cash/z/ecc/android/sdk/util/TestWallet.kt @@ -23,6 +23,7 @@ import kotlinx.coroutines.flow.onCompletion import kotlinx.coroutines.flow.takeWhile import kotlinx.coroutines.launch import kotlinx.coroutines.newFixedThreadPoolContext +import kotlinx.coroutines.runBlocking import java.util.concurrent.TimeoutException /** @@ -51,19 +52,29 @@ class TestWallet( val walletScope = CoroutineScope( SupervisorJob() + newFixedThreadPoolContext(3, this.javaClass.simpleName) ) + + // Although runBlocking isn't great, this usage is OK because this is only used within the + // automated tests + private val context = InstrumentationRegistry.getInstrumentation().context private val seed: ByteArray = Mnemonics.MnemonicCode(seedPhrase).toSeed() - private val shieldedSpendingKey = DerivationTool.deriveSpendingKeys(seed, network = network)[0] - private val transparentSecretKey = DerivationTool.deriveTransparentSecretKey(seed, network = network) - val initializer = Initializer(context) { config -> - config.importWallet(seed, startHeight, network, host, alias = alias) + private val shieldedSpendingKey = + runBlocking { DerivationTool.deriveSpendingKeys(seed, network = network)[0] } + private val transparentSecretKey = + runBlocking { DerivationTool.deriveTransparentSecretKey(seed, network = network) } + val initializer = runBlocking { + Initializer.new(context) { config -> + runBlocking { config.importWallet(seed, startHeight, network, host, alias = alias) } + } } val synchronizer: SdkSynchronizer = Synchronizer(initializer) as SdkSynchronizer val service = (synchronizer.processor.downloader.lightWalletService as LightWalletGrpcService) val available get() = synchronizer.saplingBalances.value.availableZatoshi - val shieldedAddress = DerivationTool.deriveShieldedAddress(seed, network = network) - val transparentAddress = DerivationTool.deriveTransparentAddress(seed, network = network) + val shieldedAddress = + runBlocking { DerivationTool.deriveShieldedAddress(seed, network = network) } + val transparentAddress = + runBlocking { DerivationTool.deriveTransparentAddress(seed, network = network) } val birthdayHeight get() = synchronizer.latestBirthdayHeight val networkName get() = synchronizer.network.networkName val connectionInfo get() = service.connectionInfo.toString() diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/Initializer.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/Initializer.kt index b27a34ee..f08a2ef5 100644 --- a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/Initializer.kt +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/Initializer.kt @@ -4,95 +4,37 @@ import android.content.Context import cash.z.ecc.android.sdk.exception.InitializerException import cash.z.ecc.android.sdk.ext.ZcashSdk import cash.z.ecc.android.sdk.internal.twig +import cash.z.ecc.android.sdk.internal.SdkDispatchers +import cash.z.ecc.android.sdk.internal.ext.getCacheDirSuspend +import cash.z.ecc.android.sdk.internal.ext.getDatabasePathSuspend import cash.z.ecc.android.sdk.jni.RustBackend import cash.z.ecc.android.sdk.tool.DerivationTool import cash.z.ecc.android.sdk.tool.WalletBirthdayTool import cash.z.ecc.android.sdk.type.UnifiedViewingKey import cash.z.ecc.android.sdk.type.WalletBirthday import cash.z.ecc.android.sdk.type.ZcashNetwork +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext import java.io.File /** * Simplified Initializer focused on starting from a ViewingKey. */ -class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Throwable?) -> Boolean)? = null, config: Config) { - val context = appContext.applicationContext - val rustBackend: RustBackend - val network: ZcashNetwork - val alias: String - val host: String - val port: Int - val viewingKeys: List - val overwriteVks: Boolean +class Initializer private constructor( + val context: Context, + val rustBackend: RustBackend, + val network: ZcashNetwork, + val alias: String, + val host: String, + val port: Int, + val viewingKeys: List, + val overwriteVks: Boolean, val birthday: WalletBirthday +) { - /** - * A callback to invoke whenever an uncaught error is encountered. By definition, the return - * value of the function is ignored because this error is unrecoverable. The only reason the - * function has a return value is so that all error handlers work with the same signature which - * allows one function to handle all errors in simple apps. - */ - var onCriticalErrorHandler: ((Throwable?) -> Boolean)? = onCriticalErrorHandler + suspend fun erase() = erase(context, network, alias) - init { - try { - config.validate() - network = config.network - val heightToUse = config.birthdayHeight - ?: (if (config.defaultToOldestHeight == true) network.saplingActivationHeight else null) - val loadedBirthday = WalletBirthdayTool.loadNearest(context, network, heightToUse) - birthday = loadedBirthday - viewingKeys = config.viewingKeys - overwriteVks = config.overwriteVks - alias = config.alias - host = config.host - port = config.port - rustBackend = initRustBackend(network, birthday) - } catch (t: Throwable) { - onCriticalError(t) - throw t - } - } - - constructor(appContext: Context, config: Config) : this(appContext, null, config) - constructor(appContext: Context, onCriticalErrorHandler: ((Throwable?) -> Boolean)? = null, block: (Config) -> Unit) : this(appContext, onCriticalErrorHandler, Config(block)) - - fun erase() = erase(context, network, alias) - - private fun initRustBackend(network: ZcashNetwork, birthday: WalletBirthday): RustBackend { - return RustBackend.init( - cacheDbPath(context, network, alias), - dataDbPath(context, network, alias), - "${context.cacheDir.absolutePath}/params", - network, - birthday.height - ) - } - - private fun onCriticalError(error: Throwable) { - twig("********") - twig("******** INITIALIZER ERROR: $error") - if (error.cause != null) twig("******** caused by ${error.cause}") - if (error.cause?.cause != null) twig("******** caused by ${error.cause?.cause}") - twig("********") - twig(error) - - if (onCriticalErrorHandler == null) { - twig( - "WARNING: a critical error occurred on the Initializer but no callback is " + - "registered to be notified of critical errors! THIS IS PROBABLY A MISTAKE. To " + - "respond to these errors (perhaps to update the UI or alert the user) set " + - "initializer.onCriticalErrorHandler to a non-null value or use the secondary " + - "constructor: Initializer(context, handler) { ... }. Note that the synchronizer " + - "and initializer BOTH have error handlers and since the initializer exists " + - "before the synchronizer, it needs its error handler set separately." - ) - } - - onCriticalErrorHandler?.invoke(error) - } - - class Config private constructor ( + class Config private constructor( val viewingKeys: MutableList = mutableListOf(), var alias: String = ZcashSdk.DEFAULT_ALIAS, ) { @@ -177,7 +119,10 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr * is not currently well supported. Consider it an alpha-preview feature that might work but * probably has serious bugs. */ - fun setViewingKeys(vararg unifiedViewingKeys: UnifiedViewingKey, overwrite: Boolean = false): Config = apply { + fun setViewingKeys( + vararg unifiedViewingKeys: UnifiedViewingKey, + overwrite: Boolean = false + ): Config = apply { overwriteVks = overwrite viewingKeys.apply { clear() @@ -225,7 +170,7 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr /** * Import a wallet using the first viewing key derived from the given seed. */ - fun importWallet( + suspend fun importWallet( seed: ByteArray, birthdayHeight: Int? = null, network: ZcashNetwork, @@ -262,7 +207,7 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr /** * Create a new wallet using the first viewing key derived from the given seed. */ - fun newWallet( + suspend fun newWallet( seed: ByteArray, network: ZcashNetwork, host: String = network.defaultHost, @@ -296,9 +241,20 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr * Convenience method for setting thew viewingKeys from a given seed. This is the same as * calling `setViewingKeys` with the keys that match this seed. */ - fun setSeed(seed: ByteArray, network: ZcashNetwork, numberOfAccounts: Int = 1): Config = apply { - setViewingKeys(*DerivationTool.deriveUnifiedViewingKeys(seed, network, numberOfAccounts)) - } + suspend fun setSeed( + seed: ByteArray, + network: ZcashNetwork, + numberOfAccounts: Int = 1 + ): Config = + apply { + setViewingKeys( + *DerivationTool.deriveUnifiedViewingKeys( + seed, + network, + numberOfAccounts + ) + ) + } /** * Sets the network from a network id, throwing an exception if the id is not recognized. @@ -338,16 +294,89 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr private fun validateViewingKeys() { require(viewingKeys.isNotEmpty()) { "Unified Viewing keys are required. Ensure that the unified viewing keys or seed" + - " have been set on this Initializer." + " have been set on this Initializer." } viewingKeys.forEach { DerivationTool.validateUnifiedViewingKey(it) } } + } companion object : SdkSynchronizer.Erasable { + suspend fun new(appContext: Context, config: Config) = new(appContext, null, config) + + suspend fun new( + appContext: Context, + onCriticalErrorHandler: ((Throwable?) -> Boolean)? = null, + block: (Config) -> Unit + ) = new(appContext, onCriticalErrorHandler, Config(block)) + + suspend fun new( + context: Context, + onCriticalErrorHandler: ((Throwable?) -> Boolean)?, + config: Config + ): Initializer { + config.validate() + val heightToUse = config.birthdayHeight + ?: (if (config.defaultToOldestHeight == true) config.network.saplingActivationHeight else null) + val loadedBirthday = + WalletBirthdayTool.loadNearest(context, config.network, heightToUse) + + val rustBackend = initRustBackend(context, config.network, config.alias, loadedBirthday) + + return Initializer( + context.applicationContext, + rustBackend, + config.network, + config.alias, + config.host, + config.port, + config.viewingKeys, + config.overwriteVks, + loadedBirthday + ) + } + + private fun onCriticalError(onCriticalErrorHandler: ((Throwable?) -> Boolean)?, error: Throwable) { + twig("********") + twig("******** INITIALIZER ERROR: $error") + if (error.cause != null) twig("******** caused by ${error.cause}") + if (error.cause?.cause != null) twig("******** caused by ${error.cause?.cause}") + twig("********") + twig(error) + + if (onCriticalErrorHandler == null) { + twig( + "WARNING: a critical error occurred on the Initializer but no callback is " + + "registered to be notified of critical errors! THIS IS PROBABLY A MISTAKE. To " + + "respond to these errors (perhaps to update the UI or alert the user) set " + + "initializer.onCriticalErrorHandler to a non-null value or use the secondary " + + "constructor: Initializer(context, handler) { ... }. Note that the synchronizer " + + "and initializer BOTH have error handlers and since the initializer exists " + + "before the synchronizer, it needs its error handler set separately." + ) + } + + onCriticalErrorHandler?.invoke(error) + } + + private suspend fun initRustBackend( + context: Context, + network: ZcashNetwork, + alias: String, + birthday: WalletBirthday + ): RustBackend { + return RustBackend.init( + cacheDbPath(context, network, alias), + dataDbPath(context, network, alias), + File(context.getCacheDirSuspend(), "params").absolutePath, + network, + birthday.height + ) + } + /** * Delete the databases associated with this wallet. This removes all compact blocks and * data derived from those blocks. For most wallets, this should not result in a loss of @@ -362,7 +391,11 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr * @return true when one of the associated files was found. False most likely indicates * that the wrong alias was provided. */ - override fun erase(appContext: Context, network: ZcashNetwork, alias: String): Boolean { + override suspend fun erase( + appContext: Context, + network: ZcashNetwork, + alias: String + ): Boolean { val cacheDeleted = deleteDb(cacheDbPath(appContext, network, alias)) val dataDeleted = deleteDb(dataDbPath(appContext, network, alias)) return dataDeleted || cacheDeleted @@ -379,7 +412,11 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr * @param network the network associated with the data in the cache database. * @param alias the alias to convert into a database path */ - internal fun cacheDbPath(appContext: Context, network: ZcashNetwork, alias: String): String = + private suspend fun cacheDbPath( + appContext: Context, + network: ZcashNetwork, + alias: String + ): String = aliasToPath(appContext, network, alias, ZcashSdk.DB_CACHE_NAME) /** @@ -388,12 +425,21 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr * * @param network the network associated with the data in the database. * @param alias the alias to convert into a database path */ - internal fun dataDbPath(appContext: Context, network: ZcashNetwork, alias: String): String = + private suspend fun dataDbPath( + appContext: Context, + network: ZcashNetwork, + alias: String + ): String = aliasToPath(appContext, network, alias, ZcashSdk.DB_DATA_NAME) - private fun aliasToPath(appContext: Context, network: ZcashNetwork, alias: String, dbFileName: String): String { + private suspend fun aliasToPath( + appContext: Context, + network: ZcashNetwork, + alias: String, + dbFileName: String + ): String { val parentDir: String = - appContext.getDatabasePath("unused.db").parentFile?.absolutePath + appContext.getDatabasePathSuspend("unused.db").parentFile?.absolutePath ?: throw InitializerException.DatabasePathException val prefix = if (alias.endsWith('_')) alias else "${alias}_" return File(parentDir, "$prefix${network.networkName}_$dbFileName").absolutePath @@ -405,9 +451,10 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr * @param filePath the path of the db to erase. * @return true when a file exists at the given path and was deleted. */ - private fun deleteDb(filePath: String): Boolean { + private suspend fun deleteDb(filePath: String): Boolean { // just try the journal file. Doesn't matter if it's not there. delete("$filePath-journal") + return delete(filePath) } @@ -417,14 +464,16 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr * @param filePath the path of the file to erase. * @return true when a file exists at the given path and was deleted. */ - private fun delete(filePath: String): Boolean { + private suspend fun delete(filePath: String): Boolean { return File(filePath).let { - if (it.exists()) { - twig("Deleting ${it.name}!") - it.delete() - true - } else { - false + withContext(SdkDispatchers.IO) { + if (it.exists()) { + twig("Deleting ${it.name}!") + it.delete() + true + } else { + false + } } } } @@ -445,9 +494,9 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr internal fun validateAlias(alias: String) { require( alias.length in 1..99 && alias[0].isLetter() && - alias.all { it.isLetterOrDigit() || it == '_' } + alias.all { it.isLetterOrDigit() || it == '_' } ) { "ERROR: Invalid alias ($alias). For security, the alias must be shorter than 100 " + - "characters and only contain letters, digits or underscores and start with a letter." + "characters and only contain letters, digits or underscores and start with a letter." } } diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/SdkSynchronizer.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/SdkSynchronizer.kt index bf996193..387333fd 100644 --- a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/SdkSynchronizer.kt +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/SdkSynchronizer.kt @@ -58,7 +58,6 @@ import io.grpc.ManagedChannel import kotlinx.coroutines.CoroutineExceptionHandler import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.Dispatchers.IO import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.Job @@ -247,7 +246,7 @@ class SdkSynchronizer internal constructor( override val latestBirthdayHeight: Int get() = processor.birthdayHeight - override fun prepare(): Synchronizer = apply { + override suspend fun prepare(): Synchronizer = apply { storage.prepare() } @@ -336,15 +335,15 @@ class SdkSynchronizer internal constructor( // TODO: turn this section into the data access API. For now, just aggregate all the things that we want to do with the underlying data - fun findBlockHash(height: Int): ByteArray? { + suspend fun findBlockHash(height: Int): ByteArray? { return (storage as? PagedTransactionRepository)?.findBlockHash(height) } - fun findBlockHashAsHex(height: Int): String? { + suspend fun findBlockHashAsHex(height: Int): String? { return findBlockHash(height)?.toHexReversed() } - fun getTransactionCount(): Int { + suspend fun getTransactionCount(): Int { return (storage as? PagedTransactionRepository)?.getTransactionCount() ?: 0 } @@ -530,7 +529,7 @@ class SdkSynchronizer internal constructor( } } - private suspend fun refreshPendingTransactions() = withContext(IO) { + private suspend fun refreshPendingTransactions() = withContext(Dispatchers.IO) { twig("[cleanup] beginning to refresh and clean up pending transactions") // TODO: this would be the place to clear out any stale pending transactions. Remove filter // logic and then delete any pending transaction with sufficient confirmations (all in one @@ -737,7 +736,7 @@ class SdkSynchronizer internal constructor( * * @return true when content was found for the given alias. False otherwise. */ - fun erase(appContext: Context, network: ZcashNetwork, alias: String = ZcashSdk.DEFAULT_ALIAS): Boolean + suspend fun erase(appContext: Context, network: ZcashNetwork, alias: String = ZcashSdk.DEFAULT_ALIAS): Boolean } } diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/Synchronizer.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/Synchronizer.kt index 3c18c8b8..0094ccc0 100644 --- a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/Synchronizer.kt +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/Synchronizer.kt @@ -35,7 +35,7 @@ interface Synchronizer { * where setup and maintenance can occur for various Synchronizers. One that uses a database * would take this opportunity to do data migrations or key migrations. */ - fun prepare(): Synchronizer + suspend fun prepare(): Synchronizer /** * Starts this synchronizer within the given scope. diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/block/CompactBlockProcessor.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/block/CompactBlockProcessor.kt index 8ecc7d6e..9e0dbb10 100644 --- a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/block/CompactBlockProcessor.kt +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/block/CompactBlockProcessor.kt @@ -526,7 +526,7 @@ class CompactBlockProcessor( * @return [ERROR_CODE_NONE] when there is no problem. Otherwise, return the lowest height where an error was * found. In other words, validation starts at the back of the chain and works toward the tip. */ - private fun validateNewBlocks(range: IntRange?): Int { + private suspend fun validateNewBlocks(range: IntRange?): Int { if (range?.isEmpty() != false) { twig("no blocks to validate: $range") return ERROR_CODE_NONE diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/SdkDispatchers.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/SdkDispatchers.kt new file mode 100644 index 00000000..20cac798 --- /dev/null +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/SdkDispatchers.kt @@ -0,0 +1,14 @@ +package cash.z.ecc.android.sdk.internal + +import kotlinx.coroutines.asCoroutineDispatcher +import java.util.concurrent.Executors + +internal object SdkDispatchers { + /* + * Based on internal discussion, keep the SDK internals confined to a single IO thread. + * + * We don't expect things to break, but we don't have the WAL enabled for SQLite so this + * is a simple solution. + */ + val IO = Executors.newSingleThreadExecutor().asCoroutineDispatcher() +} \ No newline at end of file diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/block/CompactBlockDbStore.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/block/CompactBlockDbStore.kt index 0373d3a6..4c7fedcb 100644 --- a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/block/CompactBlockDbStore.kt +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/block/CompactBlockDbStore.kt @@ -6,8 +6,9 @@ import androidx.room.RoomDatabase import cash.z.ecc.android.sdk.internal.db.CompactBlockDao import cash.z.ecc.android.sdk.internal.db.CompactBlockDb import cash.z.ecc.android.sdk.db.entity.CompactBlockEntity +import cash.z.ecc.android.sdk.internal.SdkDispatchers import cash.z.wallet.sdk.rpc.CompactFormats -import kotlinx.coroutines.Dispatchers.IO +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext /** @@ -38,7 +39,7 @@ class CompactBlockDbStore( .build() } - override suspend fun getLatestHeight(): Int = withContext(IO) { + override suspend fun getLatestHeight(): Int = withContext(SdkDispatchers.IO) { Math.max(0, cacheDao.latestBlockHeight()) } @@ -46,15 +47,17 @@ class CompactBlockDbStore( return cacheDao.findCompactBlock(height)?.let { CompactFormats.CompactBlock.parseFrom(it) } } - override suspend fun write(result: List) = withContext(IO) { + override suspend fun write(result: List) = withContext(SdkDispatchers.IO) { cacheDao.insert(result.map { CompactBlockEntity(it.height.toInt(), it.toByteArray()) }) } - override suspend fun rewindTo(height: Int) = withContext(IO) { + override suspend fun rewindTo(height: Int) = withContext(SdkDispatchers.IO) { cacheDao.rewindTo(height) } - override fun close() { - cacheDb.close() + override suspend fun close() { + withContext(SdkDispatchers.IO) { + cacheDb.close() + } } } diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/block/CompactBlockDownloader.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/block/CompactBlockDownloader.kt index c808761e..56013e7b 100644 --- a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/block/CompactBlockDownloader.kt +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/block/CompactBlockDownloader.kt @@ -7,6 +7,7 @@ import cash.z.ecc.android.sdk.internal.service.LightWalletService import cash.z.wallet.sdk.rpc.Service import io.grpc.StatusRuntimeException import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers.IO import kotlinx.coroutines.delay import kotlinx.coroutines.launch @@ -122,8 +123,10 @@ open class CompactBlockDownloader private constructor(val compactBlockStore: Com /** * Stop this downloader and cleanup any resources being used. */ - fun stop() { - lightWalletService.shutdown() + suspend fun stop() { + withContext(Dispatchers.IO) { + lightWalletService.shutdown() + } compactBlockStore.close() } diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/block/CompactBlockStore.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/block/CompactBlockStore.kt index dfdc7ed3..93058875 100644 --- a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/block/CompactBlockStore.kt +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/block/CompactBlockStore.kt @@ -37,5 +37,5 @@ interface CompactBlockStore { /** * Close any connections to the block store. */ - fun close() + suspend fun close() } diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/ext/ContextExt.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/ext/ContextExt.kt new file mode 100644 index 00000000..8e702e2f --- /dev/null +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/ext/ContextExt.kt @@ -0,0 +1,11 @@ +package cash.z.ecc.android.sdk.internal.ext + +import android.content.Context +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext + +suspend fun Context.getDatabasePathSuspend(fileName: String) = + withContext(Dispatchers.IO) { getDatabasePath(fileName) } + +suspend fun Context.getCacheDirSuspend() = + withContext(Dispatchers.IO) { cacheDir } \ No newline at end of file diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/ext/FileExt.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/ext/FileExt.kt new file mode 100644 index 00000000..126641a3 --- /dev/null +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/ext/FileExt.kt @@ -0,0 +1,7 @@ +package cash.z.ecc.android.sdk.internal.ext + +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import java.io.File + +suspend fun File.deleteSuspend() = withContext(Dispatchers.IO) { delete() } \ No newline at end of file diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/transaction/PagedTransactionRepository.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/transaction/PagedTransactionRepository.kt index 91ed7ad9..d9337afb 100644 --- a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/transaction/PagedTransactionRepository.kt +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/transaction/PagedTransactionRepository.kt @@ -16,10 +16,12 @@ import cash.z.ecc.android.sdk.internal.ext.android.toFlowPagedList import cash.z.ecc.android.sdk.internal.ext.android.toRefreshable import cash.z.ecc.android.sdk.internal.ext.tryWarn import cash.z.ecc.android.sdk.internal.twig +import cash.z.ecc.android.sdk.internal.SdkDispatchers import cash.z.ecc.android.sdk.jni.RustBackend import cash.z.ecc.android.sdk.type.UnifiedAddressAccount import cash.z.ecc.android.sdk.type.UnifiedViewingKey import cash.z.ecc.android.sdk.type.WalletBirthday +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers.IO import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.emitAll @@ -95,7 +97,7 @@ class PagedTransactionRepository( override suspend fun getAccountCount(): Int = lazy.accounts.count() - override fun prepare() { + override suspend fun prepare() { if (lazy.isPrepared.get()) { twig("Warning: skipped the preparation step because we're already prepared!") } else { @@ -112,7 +114,7 @@ class PagedTransactionRepository( * side because Rust was intended to own the "dataDb" and Kotlin just reads from it. Since then, * it has been more clear that Kotlin should own the data and just let Rust use it. */ - private fun initMissingDatabases() { + private suspend fun initMissingDatabases() { maybeCreateDataDb() maybeInitBlocksTable(birthday) maybeInitAccountsTable(viewingKeys) @@ -121,7 +123,7 @@ class PagedTransactionRepository( /** * Create the dataDb and its table, if it doesn't exist. */ - private fun maybeCreateDataDb() { + private suspend fun maybeCreateDataDb() { tryWarn("Warning: did not create dataDb. It probably already exists.") { rustBackend.initDataDb() twig("Initialized wallet for first run file: ${rustBackend.pathDataDb}") @@ -131,7 +133,7 @@ class PagedTransactionRepository( /** * Initialize the blocks table with the given birthday, if needed. */ - private fun maybeInitBlocksTable(birthday: WalletBirthday) { + private suspend fun maybeInitBlocksTable(birthday: WalletBirthday) { // TODO: consider converting these to typed exceptions in the welding layer tryWarn( "Warning: did not initialize the blocks table. It probably was already initialized.", @@ -151,7 +153,7 @@ class PagedTransactionRepository( /** * Initialize the accounts table with the given viewing keys. */ - private fun maybeInitAccountsTable(viewingKeys: List) { + private suspend fun maybeInitAccountsTable(viewingKeys: List) { // TODO: consider converting these to typed exceptions in the welding layer tryWarn( "Warning: did not initialize the accounts table. It probably was already initialized.", @@ -181,7 +183,7 @@ class PagedTransactionRepository( } } - private fun applyKeyMigrations() { + private suspend fun applyKeyMigrations() { if (overwriteVks) { twig("applying key migrations . . .") maybeInitAccountsTable(viewingKeys) @@ -191,8 +193,10 @@ class PagedTransactionRepository( /** * Close the underlying database. */ - fun close() { - lazy.db?.close() + suspend fun close() { + withContext(Dispatchers.IO) { + lazy.db?.close() + } } // TODO: begin converting these into Data Access API. For now, just collect the desired operations and iterate/refactor, later diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/transaction/TransactionRepository.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/transaction/TransactionRepository.kt index 9b18113c..dd79b258 100644 --- a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/transaction/TransactionRepository.kt +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/transaction/TransactionRepository.kt @@ -87,7 +87,7 @@ interface TransactionRepository { suspend fun getAccountCount(): Int - fun prepare() + suspend fun prepare() // // Transactions diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/transaction/WalletTransactionEncoder.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/transaction/WalletTransactionEncoder.kt index c5d7a2b3..b950a91b 100644 --- a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/transaction/WalletTransactionEncoder.kt +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/transaction/WalletTransactionEncoder.kt @@ -8,7 +8,8 @@ import cash.z.ecc.android.sdk.internal.twigTask import cash.z.ecc.android.sdk.jni.RustBackend import cash.z.ecc.android.sdk.jni.RustBackendWelding import cash.z.ecc.android.sdk.internal.SaplingParamTool -import kotlinx.coroutines.Dispatchers.IO +import cash.z.ecc.android.sdk.internal.SdkDispatchers +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext /** @@ -44,7 +45,7 @@ class WalletTransactionEncoder( toAddress: String, memo: ByteArray?, fromAccountIndex: Int - ): EncodedTransaction = withContext(IO) { + ): EncodedTransaction = withContext(SdkDispatchers.IO) { val transactionId = createSpend(spendingKey, zatoshi, toAddress, memo) repository.findEncodedTransactionById(transactionId) ?: throw TransactionEncoderException.TransactionNotFoundException(transactionId) @@ -54,7 +55,7 @@ class WalletTransactionEncoder( spendingKey: String, transparentSecretKey: String, memo: ByteArray? - ): EncodedTransaction = withContext(IO) { + ): EncodedTransaction = withContext(SdkDispatchers.IO) { val transactionId = createShieldingSpend(spendingKey, transparentSecretKey, memo) repository.findEncodedTransactionById(transactionId) ?: throw TransactionEncoderException.TransactionNotFoundException(transactionId) @@ -68,7 +69,7 @@ class WalletTransactionEncoder( * * @return true when the given address is a valid z-addr */ - override suspend fun isValidShieldedAddress(address: String): Boolean = withContext(IO) { + override suspend fun isValidShieldedAddress(address: String): Boolean = withContext(Dispatchers.IO) { rustBackend.isValidShieldedAddr(address) } @@ -80,7 +81,7 @@ class WalletTransactionEncoder( * * @return true when the given address is a valid t-addr */ - override suspend fun isValidTransparentAddress(address: String): Boolean = withContext(IO) { + override suspend fun isValidTransparentAddress(address: String): Boolean = withContext(Dispatchers.IO) { rustBackend.isValidTransparentAddr(address) } @@ -110,7 +111,7 @@ class WalletTransactionEncoder( toAddress: String, memo: ByteArray? = byteArrayOf(), fromAccountIndex: Int = 0 - ): Long = withContext(IO) { + ): Long = withContext(Dispatchers.IO) { twigTask( "creating transaction to spend $zatoshi zatoshi to" + " ${toAddress.masked()} with memo $memo" @@ -140,7 +141,7 @@ class WalletTransactionEncoder( spendingKey: String, transparentSecretKey: String, memo: ByteArray? = byteArrayOf() - ): Long = withContext(IO) { + ): Long = withContext(Dispatchers.IO) { twigTask("creating transaction to shield all UTXOs") { try { SaplingParamTool.ensureParams((rustBackend as RustBackend).pathParamsDir) diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/jni/NativeLibraryLoader.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/jni/NativeLibraryLoader.kt new file mode 100644 index 00000000..cbd5d12a --- /dev/null +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/jni/NativeLibraryLoader.kt @@ -0,0 +1,44 @@ +package cash.z.ecc.android.sdk.jni + +import cash.z.ecc.android.sdk.internal.twig +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlinx.coroutines.withContext +import java.util.concurrent.atomic.AtomicBoolean + +/** + * Loads a native library once. This class is thread-safe. + * + * To use this class, create a singleton instance for each given [libraryName]. + * + * @param libraryName Name of the library to load. + */ +internal class NativeLibraryLoader(private val libraryName: String) { + private val isLoaded = AtomicBoolean(false) + private val mutex = Mutex() + + suspend fun load() { + // Double-checked locking to avoid the Mutex unless necessary, as the hot path is + // for the library to be loaded since this should only run once for the lifetime + // of the application + if (!isLoaded.get()) { + mutex.withLock { + if (!isLoaded.get()) { + loadRustLibrary() + } + } + } + } + + private suspend fun loadRustLibrary() { + try { + withContext(Dispatchers.IO) { + twig("Loading native library $libraryName") { System.loadLibrary(libraryName) } + } + isLoaded.set(true) + } catch (e: Throwable) { + twig("Error while loading native library: ${e.message}") + } + } +} \ No newline at end of file diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/jni/RustBackend.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/jni/RustBackend.kt index 35f27233..bd3c2d2d 100644 --- a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/jni/RustBackend.kt +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/jni/RustBackend.kt @@ -4,10 +4,14 @@ import cash.z.ecc.android.sdk.exception.BirthdayException import cash.z.ecc.android.sdk.ext.ZcashSdk.OUTPUT_PARAM_FILE_NAME import cash.z.ecc.android.sdk.ext.ZcashSdk.SPEND_PARAM_FILE_NAME import cash.z.ecc.android.sdk.internal.twig +import cash.z.ecc.android.sdk.internal.SdkDispatchers +import cash.z.ecc.android.sdk.internal.ext.deleteSuspend import cash.z.ecc.android.sdk.tool.DerivationTool import cash.z.ecc.android.sdk.type.UnifiedViewingKey import cash.z.ecc.android.sdk.type.WalletBalance import cash.z.ecc.android.sdk.type.ZcashNetwork +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext import java.io.File /** @@ -17,10 +21,6 @@ import java.io.File */ class RustBackend private constructor() : RustBackendWelding { - init { - load() - } - // Paths lateinit var pathDataDb: String internal set @@ -35,14 +35,14 @@ class RustBackend private constructor() : RustBackendWelding { get() = if (field != -1) field else throw BirthdayException.UninitializedBirthdayException private set - fun clear(clearCacheDb: Boolean = true, clearDataDb: Boolean = true) { + suspend fun clear(clearCacheDb: Boolean = true, clearDataDb: Boolean = true) { if (clearCacheDb) { twig("Deleting the cache database!") - File(pathCacheDb).delete() + File(pathCacheDb).deleteSuspend() } if (clearDataDb) { twig("Deleting the data database!") - File(pathDataDb).delete() + File(pathDataDb).deleteSuspend() } } @@ -50,19 +50,31 @@ class RustBackend private constructor() : RustBackendWelding { // Wrapper Functions // - override fun initDataDb() = initDataDb(pathDataDb, networkId = network.id) + override suspend fun initDataDb() = withContext(SdkDispatchers.IO) { + initDataDb( + pathDataDb, + networkId = network.id + ) + } - override fun initAccountsTable(vararg keys: UnifiedViewingKey): Boolean { + override suspend fun initAccountsTable(vararg keys: UnifiedViewingKey): Boolean { val extfvks = Array(keys.size) { "" } val extpubs = Array(keys.size) { "" } keys.forEachIndexed { i, key -> extfvks[i] = key.extfvk extpubs[i] = key.extpub } - return initAccountsTableWithKeys(pathDataDb, extfvks, extpubs, networkId = network.id) + return withContext(SdkDispatchers.IO) { + initAccountsTableWithKeys( + pathDataDb, + extfvks, + extpubs, + networkId = network.id + ) + } } - override fun initAccountsTable( + override suspend fun initAccountsTable( seed: ByteArray, numberOfAccounts: Int ): Array { @@ -71,82 +83,131 @@ class RustBackend private constructor() : RustBackendWelding { } } - override fun initBlocksTable( + override suspend fun initBlocksTable( height: Int, hash: String, time: Long, saplingTree: String ): Boolean { - return initBlocksTable(pathDataDb, height, hash, time, saplingTree, networkId = network.id) + return withContext(SdkDispatchers.IO) { + initBlocksTable( + pathDataDb, + height, + hash, + time, + saplingTree, + networkId = network.id + ) + } } - override fun getShieldedAddress(account: Int) = getShieldedAddress(pathDataDb, account, networkId = network.id) + override suspend fun getShieldedAddress(account: Int) = withContext(SdkDispatchers.IO) { + getShieldedAddress( + pathDataDb, + account, + networkId = network.id + ) + } - override fun getTransparentAddress(account: Int, index: Int): String { + override suspend fun getTransparentAddress(account: Int, index: Int): String { throw NotImplementedError("TODO: implement this at the zcash_client_sqlite level. But for now, use DerivationTool, instead to derive addresses from seeds") } - override fun getBalance(account: Int) = getBalance(pathDataDb, account, networkId = network.id) + override suspend fun getBalance(account: Int) = withContext(SdkDispatchers.IO) { + getBalance( + pathDataDb, + account, + networkId = network.id + ) + } - override fun getVerifiedBalance(account: Int) = getVerifiedBalance(pathDataDb, account, networkId = network.id) + override suspend fun getVerifiedBalance(account: Int) = withContext(SdkDispatchers.IO) { + getVerifiedBalance( + pathDataDb, + account, + networkId = network.id + ) + } - override fun getReceivedMemoAsUtf8(idNote: Long) = - getReceivedMemoAsUtf8(pathDataDb, idNote, networkId = network.id) + override suspend fun getReceivedMemoAsUtf8(idNote: Long) = + withContext(SdkDispatchers.IO) { getReceivedMemoAsUtf8(pathDataDb, idNote, networkId = network.id) } - override fun getSentMemoAsUtf8(idNote: Long) = getSentMemoAsUtf8(pathDataDb, idNote, networkId = network.id) + override suspend fun getSentMemoAsUtf8(idNote: Long) = withContext(SdkDispatchers.IO) { + getSentMemoAsUtf8( + pathDataDb, + idNote, + networkId = network.id + ) + } - override fun validateCombinedChain() = validateCombinedChain(pathCacheDb, pathDataDb, networkId = network.id,) + override suspend fun validateCombinedChain() = withContext(SdkDispatchers.IO) { + validateCombinedChain( + pathCacheDb, + pathDataDb, + networkId = network.id, + ) + } - override fun getNearestRewindHeight(height: Int): Int = getNearestRewindHeight(pathDataDb, height, networkId = network.id) + override suspend fun getNearestRewindHeight(height: Int): Int = withContext(SdkDispatchers.IO) { + getNearestRewindHeight( + pathDataDb, + height, + networkId = network.id + ) + } /** * Deletes data for all blocks above the given height. Boils down to: * * DELETE FROM blocks WHERE height > ? */ - override fun rewindToHeight(height: Int) = rewindToHeight(pathDataDb, height, networkId = network.id) + override suspend fun rewindToHeight(height: Int) = + withContext(SdkDispatchers.IO) { rewindToHeight(pathDataDb, height, networkId = network.id) } - override fun scanBlocks(limit: Int): Boolean { + override suspend fun scanBlocks(limit: Int): Boolean { return if (limit > 0) { - scanBlockBatch(pathCacheDb, pathDataDb, limit, networkId = network.id) + withContext(SdkDispatchers.IO) { + scanBlockBatch( + pathCacheDb, + pathDataDb, + limit, + networkId = network.id + ) + } } else { - scanBlocks(pathCacheDb, pathDataDb, networkId = network.id) + withContext(SdkDispatchers.IO) { + scanBlocks( + pathCacheDb, + pathDataDb, + networkId = network.id + ) + } } } - override fun decryptAndStoreTransaction(tx: ByteArray) = decryptAndStoreTransaction(pathDataDb, tx, networkId = network.id) + override suspend fun decryptAndStoreTransaction(tx: ByteArray) = withContext(SdkDispatchers.IO) { + decryptAndStoreTransaction( + pathDataDb, + tx, + networkId = network.id + ) + } - override fun createToAddress( + override suspend fun createToAddress( consensusBranchId: Long, account: Int, extsk: String, to: String, value: Long, memo: ByteArray? - ): Long = createToAddress( - pathDataDb, - consensusBranchId, - account, - extsk, - to, - value, - memo ?: ByteArray(0), - "$pathParamsDir/$SPEND_PARAM_FILE_NAME", - "$pathParamsDir/$OUTPUT_PARAM_FILE_NAME", - networkId = network.id, - ) - - override fun shieldToAddress( - extsk: String, - tsk: String, - memo: ByteArray? - ): Long { - twig("TMP: shieldToAddress with db path: $pathDataDb, ${memo?.size}") - return shieldToAddress( + ): Long = withContext(SdkDispatchers.IO) { + createToAddress( pathDataDb, - 0, + consensusBranchId, + account, extsk, - tsk, + to, + value, memo ?: ByteArray(0), "$pathParamsDir/$SPEND_PARAM_FILE_NAME", "$pathParamsDir/$OUTPUT_PARAM_FILE_NAME", @@ -154,31 +215,84 @@ class RustBackend private constructor() : RustBackendWelding { ) } - override fun putUtxo( + override suspend fun shieldToAddress( + extsk: String, + tsk: String, + memo: ByteArray? + ): Long { + twig("TMP: shieldToAddress with db path: $pathDataDb, ${memo?.size}") + return withContext(SdkDispatchers.IO) { + shieldToAddress( + pathDataDb, + 0, + extsk, + tsk, + memo ?: ByteArray(0), + "$pathParamsDir/$SPEND_PARAM_FILE_NAME", + "$pathParamsDir/$OUTPUT_PARAM_FILE_NAME", + networkId = network.id, + ) + } + } + + override suspend fun putUtxo( tAddress: String, txId: ByteArray, index: Int, script: ByteArray, value: Long, height: Int - ): Boolean = putUtxo(pathDataDb, tAddress, txId, index, script, value, height, networkId = network.id) + ): Boolean = withContext(SdkDispatchers.IO) { + putUtxo( + pathDataDb, + tAddress, + txId, + index, + script, + value, + height, + networkId = network.id + ) + } - override fun clearUtxos( + override suspend fun clearUtxos( tAddress: String, aboveHeight: Int, - ): Boolean = clearUtxos(pathDataDb, tAddress, aboveHeight, networkId = network.id) + ): Boolean = withContext(SdkDispatchers.IO) { + clearUtxos( + pathDataDb, + tAddress, + aboveHeight, + networkId = network.id + ) + } - override fun getDownloadedUtxoBalance(address: String): WalletBalance { - val verified = getVerifiedTransparentBalance(pathDataDb, address, networkId = network.id) - val total = getTotalTransparentBalance(pathDataDb, address, networkId = network.id) + override suspend fun getDownloadedUtxoBalance(address: String): WalletBalance { + val verified = withContext(SdkDispatchers.IO) { + getVerifiedTransparentBalance( + pathDataDb, + address, + networkId = network.id + ) + } + val total = withContext(SdkDispatchers.IO) { + getTotalTransparentBalance( + pathDataDb, + address, + networkId = network.id + ) + } return WalletBalance(total, verified) } - override fun isValidShieldedAddr(addr: String) = isValidShieldedAddress(addr, networkId = network.id) + override fun isValidShieldedAddr(addr: String) = + isValidShieldedAddress(addr, networkId = network.id) - override fun isValidTransparentAddr(addr: String) = isValidTransparentAddress(addr, networkId = network.id) + override fun isValidTransparentAddr(addr: String) = + isValidTransparentAddress(addr, networkId = network.id) - override fun getBranchIdForHeight(height: Int): Long = branchIdForHeight(height, networkId = network.id) + override fun getBranchIdForHeight(height: Int): Long = + branchIdForHeight(height, networkId = network.id) // /** // * This is a proof-of-concept for doing Local RPC, where we are effectively using the JNI @@ -203,19 +317,21 @@ class RustBackend private constructor() : RustBackendWelding { * Exposes all of the librustzcash functions along with helpers for loading the static library. */ companion object { - private var loaded = false + internal val rustLibraryLoader = NativeLibraryLoader("zcashwalletsdk") /** * Loads the library and initializes path variables. Although it is best to only call this * function once, it is idempotent. */ - fun init( + suspend fun init( cacheDbPath: String, dataDbPath: String, paramsPath: String, zcashNetwork: ZcashNetwork, birthdayHeight: Int? = null ): RustBackend { + rustLibraryLoader.load() + return RustBackend().apply { pathCacheDb = cacheDbPath pathDataDb = dataDbPath @@ -227,16 +343,6 @@ class RustBackend private constructor() : RustBackendWelding { } } - fun load() { - // It is safe to call these things twice but not efficient. So we add a loose check and - // ignore the fact that it's not thread-safe. - if (!loaded) { - twig("Loading RustBackend") { - loadRustLibrary() - } - } - } - /** * Forwards Rust logs to logcat. This is a function that is intended for debug purposes. All * logs will be tagged with `cash.z.rust.logs`. Typically, a developer would clone @@ -249,33 +355,24 @@ class RustBackend private constructor() : RustBackendWelding { */ fun enableRustLogs() = initLogs() - /** - * The first call made to this object in order to load the Rust backend library. All other - * external function calls will fail if the libraries have not been loaded. - */ - private fun loadRustLibrary() { - try { - System.loadLibrary("zcashwalletsdk") - loaded = true - } catch (e: Throwable) { - twig("Error while loading native library: ${e.message}") - } - } // // External Functions // - @JvmStatic private external fun initDataDb(dbDataPath: String, networkId: Int): Boolean + @JvmStatic + private external fun initDataDb(dbDataPath: String, networkId: Int): Boolean - @JvmStatic private external fun initAccountsTableWithKeys( + @JvmStatic + private external fun initAccountsTableWithKeys( dbDataPath: String, extfvk: Array, extpub: Array, networkId: Int, ): Boolean - @JvmStatic private external fun initBlocksTable( + @JvmStatic + private external fun initBlocksTable( dbDataPath: String, height: Int, hash: String, @@ -364,7 +461,8 @@ class RustBackend private constructor() : RustBackendWelding { networkId: Int, ) - @JvmStatic private external fun createToAddress( + @JvmStatic + private external fun createToAddress( dbDataPath: String, consensusBranchId: Long, account: Int, @@ -377,7 +475,8 @@ class RustBackend private constructor() : RustBackendWelding { networkId: Int, ): Long - @JvmStatic private external fun shieldToAddress( + @JvmStatic + private external fun shieldToAddress( dbDataPath: String, account: Int, extsk: String, @@ -388,11 +487,14 @@ class RustBackend private constructor() : RustBackendWelding { networkId: Int, ): Long - @JvmStatic private external fun initLogs() + @JvmStatic + private external fun initLogs() - @JvmStatic private external fun branchIdForHeight(height: Int, networkId: Int): Long + @JvmStatic + private external fun branchIdForHeight(height: Int, networkId: Int): Long - @JvmStatic private external fun putUtxo( + @JvmStatic + private external fun putUtxo( dbDataPath: String, tAddress: String, txId: ByteArray, @@ -403,23 +505,27 @@ class RustBackend private constructor() : RustBackendWelding { networkId: Int, ): Boolean - @JvmStatic private external fun clearUtxos( + @JvmStatic + private external fun clearUtxos( dbDataPath: String, tAddress: String, aboveHeight: Int, networkId: Int, ): Boolean - @JvmStatic private external fun getVerifiedTransparentBalance( + @JvmStatic + private external fun getVerifiedTransparentBalance( pathDataDb: String, taddr: String, networkId: Int, ): Long - @JvmStatic private external fun getTotalTransparentBalance( + @JvmStatic + private external fun getTotalTransparentBalance( pathDataDb: String, taddr: String, networkId: Int, ): Long } } + diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/jni/RustBackendWelding.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/jni/RustBackendWelding.kt index d1595601..82268809 100644 --- a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/jni/RustBackendWelding.kt +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/jni/RustBackendWelding.kt @@ -14,7 +14,7 @@ interface RustBackendWelding { val network: ZcashNetwork - fun createToAddress( + suspend fun createToAddress( consensusBranchId: Long, account: Int, extsk: String, @@ -23,51 +23,51 @@ interface RustBackendWelding { memo: ByteArray? = byteArrayOf() ): Long - fun shieldToAddress( + suspend fun shieldToAddress( extsk: String, tsk: String, memo: ByteArray? = byteArrayOf() ): Long - fun decryptAndStoreTransaction(tx: ByteArray) + suspend fun decryptAndStoreTransaction(tx: ByteArray) - fun initAccountsTable(seed: ByteArray, numberOfAccounts: Int): Array + suspend fun initAccountsTable(seed: ByteArray, numberOfAccounts: Int): Array - fun initAccountsTable(vararg keys: UnifiedViewingKey): Boolean + suspend fun initAccountsTable(vararg keys: UnifiedViewingKey): Boolean - fun initBlocksTable(height: Int, hash: String, time: Long, saplingTree: String): Boolean + suspend fun initBlocksTable(height: Int, hash: String, time: Long, saplingTree: String): Boolean - fun initDataDb(): Boolean + suspend fun initDataDb(): Boolean fun isValidShieldedAddr(addr: String): Boolean fun isValidTransparentAddr(addr: String): Boolean - fun getShieldedAddress(account: Int = 0): String + suspend fun getShieldedAddress(account: Int = 0): String - fun getTransparentAddress(account: Int = 0, index: Int = 0): String + suspend fun getTransparentAddress(account: Int = 0, index: Int = 0): String - fun getBalance(account: Int = 0): Long + suspend fun getBalance(account: Int = 0): Long fun getBranchIdForHeight(height: Int): Long - fun getReceivedMemoAsUtf8(idNote: Long): String + suspend fun getReceivedMemoAsUtf8(idNote: Long): String - fun getSentMemoAsUtf8(idNote: Long): String + suspend fun getSentMemoAsUtf8(idNote: Long): String - fun getVerifiedBalance(account: Int = 0): Long + suspend fun getVerifiedBalance(account: Int = 0): Long // fun parseTransactionDataList(tdl: LocalRpcTypes.TransactionDataList): LocalRpcTypes.TransparentTransactionList - fun getNearestRewindHeight(height: Int): Int + suspend fun getNearestRewindHeight(height: Int): Int - fun rewindToHeight(height: Int): Boolean + suspend fun rewindToHeight(height: Int): Boolean - fun scanBlocks(limit: Int = -1): Boolean + suspend fun scanBlocks(limit: Int = -1): Boolean - fun validateCombinedChain(): Int + suspend fun validateCombinedChain(): Int - fun putUtxo( + suspend fun putUtxo( tAddress: String, txId: ByteArray, index: Int, @@ -76,59 +76,59 @@ interface RustBackendWelding { height: Int ): Boolean - fun clearUtxos(tAddress: String, aboveHeight: Int = network.saplingActivationHeight - 1): Boolean + suspend fun clearUtxos(tAddress: String, aboveHeight: Int = network.saplingActivationHeight - 1): Boolean - fun getDownloadedUtxoBalance(address: String): WalletBalance + suspend fun getDownloadedUtxoBalance(address: String): WalletBalance // Implemented by `DerivationTool` interface Derivation { - fun deriveShieldedAddress( + suspend fun deriveShieldedAddress( viewingKey: String, network: ZcashNetwork ): String - fun deriveShieldedAddress( + suspend fun deriveShieldedAddress( seed: ByteArray, network: ZcashNetwork, accountIndex: Int = 0, ): String - fun deriveSpendingKeys( + suspend fun deriveSpendingKeys( seed: ByteArray, network: ZcashNetwork, numberOfAccounts: Int = 1, ): Array - fun deriveTransparentAddress( + suspend fun deriveTransparentAddress( seed: ByteArray, network: ZcashNetwork, account: Int = 0, index: Int = 0, ): String - fun deriveTransparentAddressFromPublicKey( + suspend fun deriveTransparentAddressFromPublicKey( publicKey: String, network: ZcashNetwork ): String - fun deriveTransparentAddressFromPrivateKey( + suspend fun deriveTransparentAddressFromPrivateKey( privateKey: String, network: ZcashNetwork ): String - fun deriveTransparentSecretKey( + suspend fun deriveTransparentSecretKey( seed: ByteArray, network: ZcashNetwork, account: Int = 0, index: Int = 0, ): String - fun deriveViewingKey( + suspend fun deriveViewingKey( spendingKey: String, network: ZcashNetwork ): String - fun deriveUnifiedViewingKeys( + suspend fun deriveUnifiedViewingKeys( seed: ByteArray, network: ZcashNetwork, numberOfAccounts: Int = 1, diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/tool/DerivationTool.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/tool/DerivationTool.kt index 4a228bea..8a8ed9e5 100644 --- a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/tool/DerivationTool.kt +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/tool/DerivationTool.kt @@ -18,7 +18,7 @@ class DerivationTool { * * @return the viewing keys that correspond to the seed, formatted as Strings. */ - override fun deriveUnifiedViewingKeys(seed: ByteArray, network: ZcashNetwork, numberOfAccounts: Int): Array = + override suspend fun deriveUnifiedViewingKeys(seed: ByteArray, network: ZcashNetwork, numberOfAccounts: Int): Array = withRustBackendLoaded { deriveUnifiedViewingKeysFromSeed(seed, numberOfAccounts, networkId = network.id).map { UnifiedViewingKey(it[0], it[1]) @@ -32,7 +32,7 @@ class DerivationTool { * * @return the viewing key that corresponds to the spending key. */ - override fun deriveViewingKey(spendingKey: String, network: ZcashNetwork): String = withRustBackendLoaded { + override suspend fun deriveViewingKey(spendingKey: String, network: ZcashNetwork): String = withRustBackendLoaded { deriveExtendedFullViewingKey(spendingKey, networkId = network.id) } @@ -45,7 +45,7 @@ class DerivationTool { * * @return the spending keys that correspond to the seed, formatted as Strings. */ - override fun deriveSpendingKeys(seed: ByteArray, network: ZcashNetwork, numberOfAccounts: Int): Array = + override suspend fun deriveSpendingKeys(seed: ByteArray, network: ZcashNetwork, numberOfAccounts: Int): Array = withRustBackendLoaded { deriveExtendedSpendingKeys(seed, numberOfAccounts, networkId = network.id) } @@ -59,7 +59,7 @@ class DerivationTool { * * @return the address that corresponds to the seed and account index. */ - override fun deriveShieldedAddress(seed: ByteArray, network: ZcashNetwork, accountIndex: Int): String = + override suspend fun deriveShieldedAddress(seed: ByteArray, network: ZcashNetwork, accountIndex: Int): String = withRustBackendLoaded { deriveShieldedAddressFromSeed(seed, accountIndex, networkId = network.id) } @@ -72,26 +72,26 @@ class DerivationTool { * * @return the address that corresponds to the viewing key. */ - override fun deriveShieldedAddress(viewingKey: String, network: ZcashNetwork): String = withRustBackendLoaded { + override suspend fun deriveShieldedAddress(viewingKey: String, network: ZcashNetwork): String = withRustBackendLoaded { deriveShieldedAddressFromViewingKey(viewingKey, networkId = network.id) } // WIP probably shouldn't be used just yet. Why? // - because we need the private key associated with this seed and this function doesn't return it. // - the underlying implementation needs to be split out into a few lower-level calls - override fun deriveTransparentAddress(seed: ByteArray, network: ZcashNetwork, account: Int, index: Int): String = withRustBackendLoaded { + override suspend fun deriveTransparentAddress(seed: ByteArray, network: ZcashNetwork, account: Int, index: Int): String = withRustBackendLoaded { deriveTransparentAddressFromSeed(seed, account, index, networkId = network.id) } - override fun deriveTransparentAddressFromPublicKey(transparentPublicKey: String, network: ZcashNetwork): String = withRustBackendLoaded { + override suspend fun deriveTransparentAddressFromPublicKey(transparentPublicKey: String, network: ZcashNetwork): String = withRustBackendLoaded { deriveTransparentAddressFromPubKey(transparentPublicKey, networkId = network.id) } - override fun deriveTransparentAddressFromPrivateKey(transparentPrivateKey: String, network: ZcashNetwork): String = withRustBackendLoaded { + override suspend fun deriveTransparentAddressFromPrivateKey(transparentPrivateKey: String, network: ZcashNetwork): String = withRustBackendLoaded { deriveTransparentAddressFromPrivKey(transparentPrivateKey, networkId = network.id) } - override fun deriveTransparentSecretKey(seed: ByteArray, network: ZcashNetwork, account: Int, index: Int): String = withRustBackendLoaded { + override suspend fun deriveTransparentSecretKey(seed: ByteArray, network: ZcashNetwork, account: Int, index: Int): String = withRustBackendLoaded { deriveTransparentSecretKeyFromSeed(seed, account, index, networkId = network.id) } @@ -104,8 +104,8 @@ class DerivationTool { * class attempts to interact with it, indirectly, by invoking JNI functions. It would be * nice to have an annotation like @UsesSystemLibrary for this */ - private fun withRustBackendLoaded(block: () -> T): T { - RustBackend.load() + private suspend fun withRustBackendLoaded(block: () -> T): T { + RustBackend.rustLibraryLoader.load() return block() } diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/tool/WalletBirthdayTool.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/tool/WalletBirthdayTool.kt index 4eca01ae..6274383a 100644 --- a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/tool/WalletBirthdayTool.kt +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/tool/WalletBirthdayTool.kt @@ -8,165 +8,156 @@ import cash.z.ecc.android.sdk.type.WalletBirthday import cash.z.ecc.android.sdk.type.ZcashNetwork import com.google.gson.Gson import com.google.gson.stream.JsonReader +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext import java.io.IOException import java.io.InputStreamReader import java.util.Locale /** * Tool for loading checkpoints for the wallet, based on the height at which the wallet was born. - * - * @param appContext needed for loading checkpoints from the app's assets directory. */ -class WalletBirthdayTool(appContext: Context) { - private val context = appContext.applicationContext +object WalletBirthdayTool { + + // Behavior change implemented as a fix for issue #270. Temporarily adding a boolean + // that allows the change to be rolled back quickly if needed, although long-term + // this flag should be removed. + @VisibleForTesting + internal val IS_FALLBACK_ON_FAILURE = true /** * Load the nearest checkpoint to the given birthday height. If null is given, then this * will load the most recent checkpoint available. */ - fun loadNearest(network: ZcashNetwork, birthdayHeight: Int? = null): WalletBirthday { + suspend fun loadNearest( + context: Context, + network: ZcashNetwork, + birthdayHeight: Int? = null + ): WalletBirthday { + // TODO: potentially pull from shared preferences first return loadBirthdayFromAssets(context, network, birthdayHeight) } - companion object { - - // Behavior change implemented as a fix for issue #270. Temporarily adding a boolean - // that allows the change to be rolled back quickly if needed, although long-term - // this flag should be removed. - @VisibleForTesting - internal val IS_FALLBACK_ON_FAILURE = true - - /** - * Load the nearest checkpoint to the given birthday height. If null is given, then this - * will load the most recent checkpoint available. - */ - fun loadNearest( - context: Context, - network: ZcashNetwork, - birthdayHeight: Int? = null - ): WalletBirthday { - // TODO: potentially pull from shared preferences first - return loadBirthdayFromAssets(context, network, birthdayHeight) - } - - /** - * Useful for when an exact checkpoint is needed, like for SAPLING_ACTIVATION_HEIGHT. In - * most cases, loading the nearest checkpoint is preferred for privacy reasons. - */ - fun loadExact(context: Context, network: ZcashNetwork, birthdayHeight: Int) = - loadNearest(context, network, birthdayHeight).also { - if (it.height != birthdayHeight) - throw BirthdayException.ExactBirthdayNotFoundException( - birthdayHeight, - it.height - ) - } - - // TODO: This method performs disk IO; convert to suspending function - // Converting this to suspending will then propagate - @Throws(IOException::class) - internal fun listBirthdayDirectoryContents(context: Context, directory: String) = - context.assets.list(directory) - - /** - * Returns the directory within the assets folder where birthday data - * (i.e. sapling trees for a given height) can be found. - */ - @VisibleForTesting - internal fun birthdayDirectory(network: ZcashNetwork) = - "saplingtree/${(network.networkName as java.lang.String).toLowerCase(Locale.US)}" - - internal fun birthdayHeight(fileName: String) = fileName.split('.').first().toInt() - - private fun Array.sortDescending() = - apply { sortByDescending { birthdayHeight(it) } } - - /** - * Load the given birthday file from the assets of the given context. When no height is - * specified, we default to the file with the greatest name. - * - * @param context the context from which to load assets. - * @param birthdayHeight the height file to look for among the file names. - * - * @return a WalletBirthday that reflects the contents of the file or an exception when - * parsing fails. - */ - private fun loadBirthdayFromAssets( - context: Context, - network: ZcashNetwork, - birthdayHeight: Int? = null - ): WalletBirthday { - twig("loading birthday from assets: $birthdayHeight") - val directory = birthdayDirectory(network) - val treeFiles = getFilteredFileNames(context, directory, birthdayHeight) - - twig("found ${treeFiles.size} sapling tree checkpoints: $treeFiles") - - return getFirstValidWalletBirthday(context, directory, treeFiles) - } - - private fun getFilteredFileNames( - context: Context, - directory: String, - birthdayHeight: Int? = null, - ): List { - val unfilteredTreeFiles = listBirthdayDirectoryContents(context, directory) - if (unfilteredTreeFiles.isNullOrEmpty()) { - throw BirthdayException.MissingBirthdayFilesException(directory) - } - - val filteredTreeFiles = unfilteredTreeFiles - .sortDescending() - .filter { filename -> - birthdayHeight?.let { birthdayHeight(filename) <= it } ?: true - } - - if (filteredTreeFiles.isEmpty()) { - throw BirthdayException.BirthdayFileNotFoundException( - directory, - birthdayHeight + /** + * Useful for when an exact checkpoint is needed, like for SAPLING_ACTIVATION_HEIGHT. In + * most cases, loading the nearest checkpoint is preferred for privacy reasons. + */ + suspend fun loadExact(context: Context, network: ZcashNetwork, birthdayHeight: Int) = + loadNearest(context, network, birthdayHeight).also { + if (it.height != birthdayHeight) + throw BirthdayException.ExactBirthdayNotFoundException( + birthdayHeight, + it.height ) - } - - return filteredTreeFiles } - /** - * @param treeFiles A list of files, sorted in descending order based on `int` value of the first part of the filename. - */ - @VisibleForTesting - internal fun getFirstValidWalletBirthday( - context: Context, - directory: String, - treeFiles: List - ): WalletBirthday { - var lastException: Exception? = null - treeFiles.forEach { treefile -> - try { + // Converting this to suspending will then propagate + @Throws(IOException::class) + internal suspend fun listBirthdayDirectoryContents(context: Context, directory: String) = + withContext(Dispatchers.IO) { + context.assets.list(directory) + } + + /** + * Returns the directory within the assets folder where birthday data + * (i.e. sapling trees for a given height) can be found. + */ + @VisibleForTesting + internal fun birthdayDirectory(network: ZcashNetwork) = + "saplingtree/${(network.networkName as java.lang.String).toLowerCase(Locale.US)}" + + internal fun birthdayHeight(fileName: String) = fileName.split('.').first().toInt() + + private fun Array.sortDescending() = + apply { sortByDescending { birthdayHeight(it) } } + + /** + * Load the given birthday file from the assets of the given context. When no height is + * specified, we default to the file with the greatest name. + * + * @param context the context from which to load assets. + * @param birthdayHeight the height file to look for among the file names. + * + * @return a WalletBirthday that reflects the contents of the file or an exception when + * parsing fails. + */ + private suspend fun loadBirthdayFromAssets( + context: Context, + network: ZcashNetwork, + birthdayHeight: Int? = null + ): WalletBirthday { + twig("loading birthday from assets: $birthdayHeight") + val directory = birthdayDirectory(network) + val treeFiles = getFilteredFileNames(context, directory, birthdayHeight) + + twig("found ${treeFiles.size} sapling tree checkpoints: $treeFiles") + + return getFirstValidWalletBirthday(context, directory, treeFiles) + } + + private suspend fun getFilteredFileNames( + context: Context, + directory: String, + birthdayHeight: Int? = null, + ): List { + val unfilteredTreeFiles = listBirthdayDirectoryContents(context, directory) + if (unfilteredTreeFiles.isNullOrEmpty()) { + throw BirthdayException.MissingBirthdayFilesException(directory) + } + + val filteredTreeFiles = unfilteredTreeFiles + .sortDescending() + .filter { filename -> + birthdayHeight?.let { birthdayHeight(filename) <= it } ?: true + } + + if (filteredTreeFiles.isEmpty()) { + throw BirthdayException.BirthdayFileNotFoundException( + directory, + birthdayHeight + ) + } + + return filteredTreeFiles + } + + /** + * @param treeFiles A list of files, sorted in descending order based on `int` value of the first part of the filename. + */ + @VisibleForTesting + internal suspend fun getFirstValidWalletBirthday( + context: Context, + directory: String, + treeFiles: List + ): WalletBirthday { + var lastException: Exception? = null + treeFiles.forEach { treefile -> + try { + return withContext(Dispatchers.IO) { context.assets.open("$directory/$treefile").use { inputStream -> InputStreamReader(inputStream).use { inputStreamReader -> JsonReader(inputStreamReader).use { jsonReader -> - return Gson().fromJson(jsonReader, WalletBirthday::class.java) + Gson().fromJson(jsonReader, WalletBirthday::class.java) } } } - } catch (t: Throwable) { - val exception = BirthdayException.MalformattedBirthdayFilesException( - directory, - treefile - ) - lastException = exception + } + } catch (t: Throwable) { + val exception = BirthdayException.MalformattedBirthdayFilesException( + directory, + treefile + ) + lastException = exception - if (IS_FALLBACK_ON_FAILURE) { - // TODO: If we ever add crash analytics hooks, this would be something to report - twig("Malformed birthday file $t") - } else { - throw exception - } + if (IS_FALLBACK_ON_FAILURE) { + // TODO: If we ever add crash analytics hooks, this would be something to report + twig("Malformed birthday file $t") + } else { + throw exception } } - - throw lastException!! } + + throw lastException!! } }