[#269] Convert blocking calls to suspend functions

To quickly get this implemented, some calls in the demo-app have been wrapped in `runBlocking {}`.  This is not ideal and will be addressed in followup issues.
This commit is contained in:
Carter Jernigan 2021-10-21 16:05:02 -04:00
parent 07a00dc376
commit 079229972f
37 changed files with 803 additions and 479 deletions

View File

@ -5,6 +5,8 @@ Upcoming Migrating to Version 1.4.* from 1.3.*
-------------------------------------- --------------------------------------
Various APIs that have always been considered private have been moved into a new package called `internal`. While this should not be a breaking change, clients that might have relied on these internal classes should stop doing so. Various APIs that have always been considered private have been moved into a new package called `internal`. While this should not be a breaking change, clients that might have relied on these internal classes should stop doing so.
A number of methods have been converted to suspending functions, because they were performing slow or blocking calls (e.g. disk IO) internally. This is a breaking change.
Migrating to Version 1.3.* from 1.2.* Migrating to Version 1.3.* from 1.2.*
-------------------------------------- --------------------------------------
The biggest breaking changes in 1.3 that inspired incrementing the minor version number was simplifying down to one "network aware" library rather than two separate libraries, each dedicated to either testnet or mainnet. This greatly simplifies the gradle configuration and has lots of other benefits. Wallets can now set a network with code similar to the following: The biggest breaking changes in 1.3 that inspired incrementing the minor version number was simplifying down to one "network aware" library rather than two separate libraries, each dedicated to either testnet or mainnet. This greatly simplifies the gradle configuration and has lots of other benefits. Wallets can now set a network with code similar to the following:

View File

@ -141,8 +141,10 @@ class DarksideTestCoordinator(val wallet: TestWallet) {
inner class DarksideTestValidator { inner class DarksideTestValidator {
fun validateHasBlock(height: Int) { fun validateHasBlock(height: Int) {
assertTrue((synchronizer as SdkSynchronizer).findBlockHashAsHex(height) != null) runBlocking {
assertTrue((synchronizer as SdkSynchronizer).findBlockHash(height)?.size ?: 0 > 0) assertTrue((synchronizer as SdkSynchronizer).findBlockHashAsHex(height) != null)
assertTrue((synchronizer as SdkSynchronizer).findBlockHash(height)?.size ?: 0 > 0)
}
} }
fun validateLatestHeight(height: Int) = runBlocking<Unit> { fun validateLatestHeight(height: Int) = runBlocking<Unit> {
@ -185,7 +187,7 @@ class DarksideTestCoordinator(val wallet: TestWallet) {
} }
fun validateBlockHash(height: Int, expectedHash: String) { fun validateBlockHash(height: Int, expectedHash: String) {
val hash = (synchronizer as SdkSynchronizer).findBlockHashAsHex(height) val hash = runBlocking { (synchronizer as SdkSynchronizer).findBlockHashAsHex(height) }
assertEquals(expectedHash, hash) assertEquals(expectedHash, hash)
} }
@ -194,7 +196,7 @@ class DarksideTestCoordinator(val wallet: TestWallet) {
} }
fun validateTxCount(count: Int) { fun validateTxCount(count: Int) {
val txCount = (synchronizer as SdkSynchronizer).getTransactionCount() val txCount = runBlocking { (synchronizer as SdkSynchronizer).getTransactionCount() }
assertEquals("Expected $count transactions but found $txCount instead!", count, txCount) assertEquals("Expected $count transactions but found $txCount instead!", count, txCount)
} }

View File

@ -23,6 +23,7 @@ import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.takeWhile import kotlinx.coroutines.flow.takeWhile
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.newFixedThreadPoolContext import kotlinx.coroutines.newFixedThreadPoolContext
import kotlinx.coroutines.runBlocking
import java.util.concurrent.TimeoutException import java.util.concurrent.TimeoutException
/** /**
@ -51,19 +52,29 @@ class TestWallet(
val walletScope = CoroutineScope( val walletScope = CoroutineScope(
SupervisorJob() + newFixedThreadPoolContext(3, this.javaClass.simpleName) SupervisorJob() + newFixedThreadPoolContext(3, this.javaClass.simpleName)
) )
// Although runBlocking isn't great, this usage is OK because this is only used within the
// automated tests
private val context = InstrumentationRegistry.getInstrumentation().context private val context = InstrumentationRegistry.getInstrumentation().context
private val seed: ByteArray = Mnemonics.MnemonicCode(seedPhrase).toSeed() private val seed: ByteArray = Mnemonics.MnemonicCode(seedPhrase).toSeed()
private val shieldedSpendingKey = DerivationTool.deriveSpendingKeys(seed, network = network)[0] private val shieldedSpendingKey =
private val transparentSecretKey = DerivationTool.deriveTransparentSecretKey(seed, network = network) runBlocking { DerivationTool.deriveSpendingKeys(seed, network = network)[0] }
val initializer = Initializer(context) { config -> private val transparentSecretKey =
config.importWallet(seed, startHeight, network, host, alias = alias) runBlocking { DerivationTool.deriveTransparentSecretKey(seed, network = network) }
val initializer = runBlocking {
Initializer.new(context) { config ->
runBlocking { config.importWallet(seed, startHeight, network, host, alias = alias) }
}
} }
val synchronizer: SdkSynchronizer = Synchronizer(initializer) as SdkSynchronizer val synchronizer: SdkSynchronizer = Synchronizer(initializer) as SdkSynchronizer
val service = (synchronizer.processor.downloader.lightWalletService as LightWalletGrpcService) val service = (synchronizer.processor.downloader.lightWalletService as LightWalletGrpcService)
val available get() = synchronizer.saplingBalances.value.availableZatoshi val available get() = synchronizer.saplingBalances.value.availableZatoshi
val shieldedAddress = DerivationTool.deriveShieldedAddress(seed, network = network) val shieldedAddress =
val transparentAddress = DerivationTool.deriveTransparentAddress(seed, network = network) runBlocking { DerivationTool.deriveShieldedAddress(seed, network = network) }
val transparentAddress =
runBlocking { DerivationTool.deriveTransparentAddress(seed, network = network) }
val birthdayHeight get() = synchronizer.latestBirthdayHeight val birthdayHeight get() = synchronizer.latestBirthdayHeight
val networkName get() = synchronizer.network.networkName val networkName get() = synchronizer.network.networkName
val connectionInfo get() = service.connectionInfo.toString() val connectionInfo get() = service.connectionInfo.toString()

View File

@ -52,7 +52,12 @@ class SampleCodeTest {
// /////////////////////////////////////////////////// // ///////////////////////////////////////////////////
// Derive Extended Spending Key // Derive Extended Spending Key
@Test fun deriveSpendingKey() { @Test fun deriveSpendingKey() {
val spendingKeys = DerivationTool.deriveSpendingKeys(seed, ZcashNetwork.Mainnet) val spendingKeys = runBlocking {
DerivationTool.deriveSpendingKeys(
seed,
ZcashNetwork.Mainnet
)
}
assertEquals(1, spendingKeys.size) assertEquals(1, spendingKeys.size)
log("Spending Key: ${spendingKeys?.get(0)}") log("Spending Key: ${spendingKeys?.get(0)}")
} }
@ -140,7 +145,7 @@ class SampleCodeTest {
private val lightwalletdHost: String = ZcashNetwork.Mainnet.defaultHost private val lightwalletdHost: String = ZcashNetwork.Mainnet.defaultHost
private val context = InstrumentationRegistry.getInstrumentation().targetContext private val context = InstrumentationRegistry.getInstrumentation().targetContext
private val synchronizer = Synchronizer(Initializer(context) {}) private val synchronizer = Synchronizer(runBlocking { Initializer.new(context) {} })
@BeforeClass @BeforeClass
@JvmStatic @JvmStatic

View File

@ -2,6 +2,7 @@ package cash.z.ecc.android.sdk.demoapp.demos.getaddress
import android.os.Bundle import android.os.Bundle
import android.view.LayoutInflater import android.view.LayoutInflater
import androidx.lifecycle.lifecycleScope
import cash.z.ecc.android.bip39.Mnemonics import cash.z.ecc.android.bip39.Mnemonics
import cash.z.ecc.android.bip39.toSeed import cash.z.ecc.android.bip39.toSeed
import cash.z.ecc.android.sdk.demoapp.BaseDemoFragment import cash.z.ecc.android.sdk.demoapp.BaseDemoFragment
@ -11,6 +12,8 @@ import cash.z.ecc.android.sdk.demoapp.util.fromResources
import cash.z.ecc.android.sdk.tool.DerivationTool import cash.z.ecc.android.sdk.tool.DerivationTool
import cash.z.ecc.android.sdk.type.UnifiedViewingKey import cash.z.ecc.android.sdk.type.UnifiedViewingKey
import cash.z.ecc.android.sdk.type.ZcashNetwork import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
/** /**
* Displays the address associated with the seed defined by the default config. To modify the seed * Displays the address associated with the seed defined by the default config. To modify the seed
@ -34,14 +37,16 @@ class GetAddressFragment : BaseDemoFragment<FragmentGetAddressBinding>() {
seed = Mnemonics.MnemonicCode(seedPhrase).toSeed() seed = Mnemonics.MnemonicCode(seedPhrase).toSeed()
// the derivation tool can be used for generating keys and addresses // the derivation tool can be used for generating keys and addresses
viewingKey = DerivationTool.deriveUnifiedViewingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext())).first() viewingKey = runBlocking { DerivationTool.deriveUnifiedViewingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext())).first() }
} }
private fun displayAddress() { private fun displayAddress() {
// a full fledged app would just get the address from the synchronizer // a full fledged app would just get the address from the synchronizer
val zaddress = DerivationTool.deriveShieldedAddress(seed, ZcashNetwork.fromResources(requireApplicationContext())) viewLifecycleOwner.lifecycleScope.launchWhenStarted {
val taddress = DerivationTool.deriveTransparentAddress(seed, ZcashNetwork.fromResources(requireApplicationContext())) val zaddress = DerivationTool.deriveShieldedAddress(seed, ZcashNetwork.fromResources(requireApplicationContext()))
binding.textInfo.text = "z-addr:\n$zaddress\n\n\nt-addr:\n$taddress" val taddress = DerivationTool.deriveTransparentAddress(seed, ZcashNetwork.fromResources(requireApplicationContext()))
binding.textInfo.text = "z-addr:\n$zaddress\n\n\nt-addr:\n$taddress"
}
} }
// TODO: show an example with the synchronizer // TODO: show an example with the synchronizer
@ -65,10 +70,15 @@ class GetAddressFragment : BaseDemoFragment<FragmentGetAddressBinding>() {
// //
override fun onActionButtonClicked() { override fun onActionButtonClicked() {
copyToClipboard( viewLifecycleOwner.lifecycleScope.launch {
DerivationTool.deriveShieldedAddress(viewingKey.extfvk, ZcashNetwork.fromResources(requireApplicationContext())), copyToClipboard(
"Shielded address copied to clipboard!" DerivationTool.deriveShieldedAddress(
) viewingKey.extfvk,
ZcashNetwork.fromResources(requireApplicationContext())
),
"Shielded address copied to clipboard!"
)
}
} }
override fun inflateBinding(layoutInflater: LayoutInflater): FragmentGetAddressBinding = override fun inflateBinding(layoutInflater: LayoutInflater): FragmentGetAddressBinding =

View File

@ -17,6 +17,7 @@ import cash.z.ecc.android.sdk.ext.convertZatoshiToZecString
import cash.z.ecc.android.sdk.tool.DerivationTool import cash.z.ecc.android.sdk.tool.DerivationTool
import cash.z.ecc.android.sdk.type.WalletBalance import cash.z.ecc.android.sdk.type.WalletBalance
import cash.z.ecc.android.sdk.type.ZcashNetwork import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.runBlocking
/** /**
* Displays the available balance && total balance associated with the seed defined by the default config. * Displays the available balance && total balance associated with the seed defined by the default config.
@ -43,13 +44,13 @@ class GetBalanceFragment : BaseDemoFragment<FragmentGetBalanceBinding>() {
val seed = Mnemonics.MnemonicCode(seedPhrase).toSeed() val seed = Mnemonics.MnemonicCode(seedPhrase).toSeed()
// converting seed into viewingKey // converting seed into viewingKey
val viewingKey = DerivationTool.deriveUnifiedViewingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext())).first() val viewingKey = runBlocking { DerivationTool.deriveUnifiedViewingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext())).first() }
// using the ViewingKey to initialize // using the ViewingKey to initialize
Initializer(requireApplicationContext()) { runBlocking {Initializer.new(requireApplicationContext(), null) {
it.setNetwork(ZcashNetwork.fromResources(requireApplicationContext())) it.setNetwork(ZcashNetwork.fromResources(requireApplicationContext()))
it.importWallet(viewingKey, network = ZcashNetwork.fromResources(requireApplicationContext())) it.importWallet(viewingKey, network = ZcashNetwork.fromResources(requireApplicationContext()))
}.let { initializer -> }}.let { initializer ->
synchronizer = Synchronizer(initializer) synchronizer = Synchronizer(initializer)
} }
} }

View File

@ -2,6 +2,7 @@ package cash.z.ecc.android.sdk.demoapp.demos.getprivatekey
import android.os.Bundle import android.os.Bundle
import android.view.LayoutInflater import android.view.LayoutInflater
import androidx.lifecycle.lifecycleScope
import cash.z.ecc.android.bip39.Mnemonics import cash.z.ecc.android.bip39.Mnemonics
import cash.z.ecc.android.bip39.toSeed import cash.z.ecc.android.bip39.toSeed
import cash.z.ecc.android.sdk.demoapp.BaseDemoFragment import cash.z.ecc.android.sdk.demoapp.BaseDemoFragment
@ -10,6 +11,7 @@ import cash.z.ecc.android.sdk.demoapp.ext.requireApplicationContext
import cash.z.ecc.android.sdk.demoapp.util.fromResources import cash.z.ecc.android.sdk.demoapp.util.fromResources
import cash.z.ecc.android.sdk.tool.DerivationTool import cash.z.ecc.android.sdk.tool.DerivationTool
import cash.z.ecc.android.sdk.type.ZcashNetwork import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.launch
/** /**
* Displays the viewing key and spending key associated with the seed used during the demo. The * Displays the viewing key and spending key associated with the seed used during the demo. The
@ -37,13 +39,22 @@ class GetPrivateKeyFragment : BaseDemoFragment<FragmentGetPrivateKeyBinding>() {
private fun displayKeys() { private fun displayKeys() {
// derive the keys from the seed: // derive the keys from the seed:
// demonstrate deriving spending keys for five accounts but only take the first one // demonstrate deriving spending keys for five accounts but only take the first one
val spendingKey = DerivationTool.deriveSpendingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext()), 5).first() lifecycleScope.launchWhenStarted {
val spendingKey = DerivationTool.deriveSpendingKeys(
seed,
ZcashNetwork.fromResources(requireApplicationContext()),
5
).first()
// derive the key that allows you to view but not spend transactions // derive the key that allows you to view but not spend transactions
val viewingKey = DerivationTool.deriveViewingKey(spendingKey, ZcashNetwork.fromResources(requireApplicationContext())) val viewingKey = DerivationTool.deriveViewingKey(
spendingKey,
ZcashNetwork.fromResources(requireApplicationContext())
)
// display the keys in the UI // display the keys in the UI
binding.textInfo.setText("Spending Key:\n$spendingKey\n\nViewing Key:\n$viewingKey") binding.textInfo.setText("Spending Key:\n$spendingKey\n\nViewing Key:\n$viewingKey")
}
} }
// //
@ -65,10 +76,15 @@ class GetPrivateKeyFragment : BaseDemoFragment<FragmentGetPrivateKeyBinding>() {
// //
override fun onActionButtonClicked() { override fun onActionButtonClicked() {
copyToClipboard( lifecycleScope.launch {
DerivationTool.deriveUnifiedViewingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext())).first().extpub, copyToClipboard(
"ViewingKey copied to clipboard!" DerivationTool.deriveUnifiedViewingKeys(
) seed,
ZcashNetwork.fromResources(requireApplicationContext())
).first().extpub,
"ViewingKey copied to clipboard!"
)
}
} }
override fun inflateBinding(layoutInflater: LayoutInflater): FragmentGetPrivateKeyBinding = override fun inflateBinding(layoutInflater: LayoutInflater): FragmentGetPrivateKeyBinding =

View File

@ -19,6 +19,7 @@ import cash.z.ecc.android.sdk.ext.collectWith
import cash.z.ecc.android.sdk.internal.twig import cash.z.ecc.android.sdk.internal.twig
import cash.z.ecc.android.sdk.tool.DerivationTool import cash.z.ecc.android.sdk.tool.DerivationTool
import cash.z.ecc.android.sdk.type.ZcashNetwork import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.runBlocking
/** /**
* List all transactions related to the given seed, since the given birthday. This begins by * List all transactions related to the given seed, since the given birthday. This begins by
@ -47,11 +48,16 @@ class ListTransactionsFragment : BaseDemoFragment<FragmentListTransactionsBindin
// have the seed stored // have the seed stored
val seed = Mnemonics.MnemonicCode(seedPhrase).toSeed() val seed = Mnemonics.MnemonicCode(seedPhrase).toSeed()
initializer = Initializer(requireApplicationContext()) { initializer = runBlocking {Initializer.new(requireApplicationContext()) {
it.importWallet(seed, network = ZcashNetwork.fromResources(requireApplicationContext())) runBlocking { it.importWallet(seed, network = ZcashNetwork.fromResources(requireApplicationContext())) }
it.setNetwork(ZcashNetwork.fromResources(requireApplicationContext())) it.setNetwork(ZcashNetwork.fromResources(requireApplicationContext()))
}}
address = runBlocking {
DerivationTool.deriveShieldedAddress(
seed,
ZcashNetwork.fromResources(requireApplicationContext())
)
} }
address = DerivationTool.deriveShieldedAddress(seed, ZcashNetwork.fromResources(requireApplicationContext()))
synchronizer = Synchronizer(initializer) synchronizer = Synchronizer(initializer)
} }

View File

@ -25,6 +25,7 @@ import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
/** /**
@ -60,10 +61,10 @@ class ListUtxosFragment : BaseDemoFragment<FragmentListUtxosBinding>() {
// Use a BIP-39 library to convert a seed phrase into a byte array. Most wallets already // Use a BIP-39 library to convert a seed phrase into a byte array. Most wallets already
// have the seed stored // have the seed stored
seed = Mnemonics.MnemonicCode(sharedViewModel.seedPhrase.value).toSeed() seed = Mnemonics.MnemonicCode(sharedViewModel.seedPhrase.value).toSeed()
initializer = Initializer(requireApplicationContext()) { initializer = runBlocking {Initializer.new(requireApplicationContext()) {
it.importWallet(seed, network = ZcashNetwork.fromResources(requireApplicationContext())) runBlocking { it.importWallet(seed, network = ZcashNetwork.fromResources(requireApplicationContext())) }
it.alias = "Demo_Utxos" it.alias = "Demo_Utxos"
} }}
synchronizer = Synchronizer(initializer) synchronizer = Synchronizer(initializer)
} }
@ -102,7 +103,7 @@ class ListUtxosFragment : BaseDemoFragment<FragmentListUtxosBinding>() {
txids?.map { txids?.map {
it.data.apply { it.data.apply {
try { try {
initializer.rustBackend.decryptAndStoreTransaction(toByteArray()) runBlocking { initializer.rustBackend.decryptAndStoreTransaction(toByteArray()) }
} catch (t: Throwable) { } catch (t: Throwable) {
twig("failed to decrypt and store transaction due to: $t") twig("failed to decrypt and store transaction due to: $t")
} }
@ -154,7 +155,9 @@ class ListUtxosFragment : BaseDemoFragment<FragmentListUtxosBinding>() {
super.onResume() super.onResume()
resetInBackground() resetInBackground()
val seed = Mnemonics.MnemonicCode(sharedViewModel.seedPhrase.value).toSeed() val seed = Mnemonics.MnemonicCode(sharedViewModel.seedPhrase.value).toSeed()
binding.inputAddress.setText(DerivationTool.deriveTransparentAddress(seed, ZcashNetwork.fromResources(requireApplicationContext()))) viewLifecycleOwner.lifecycleScope.launchWhenStarted {
binding.inputAddress.setText(DerivationTool.deriveTransparentAddress(seed, ZcashNetwork.fromResources(requireApplicationContext())))
}
} }
var initialCount: Int = 0 var initialCount: Int = 0

View File

@ -32,6 +32,7 @@ import cash.z.ecc.android.sdk.internal.twig
import cash.z.ecc.android.sdk.tool.DerivationTool import cash.z.ecc.android.sdk.tool.DerivationTool
import cash.z.ecc.android.sdk.type.WalletBalance import cash.z.ecc.android.sdk.type.WalletBalance
import cash.z.ecc.android.sdk.type.ZcashNetwork import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.runBlocking
/** /**
* Demonstrates sending funds to an address. This is the most complex example that puts all of the * Demonstrates sending funds to an address. This is the most complex example that puts all of the
@ -63,13 +64,13 @@ class SendFragment : BaseDemoFragment<FragmentSendBinding>() {
// have the seed stored // have the seed stored
val seed = Mnemonics.MnemonicCode(seedPhrase).toSeed() val seed = Mnemonics.MnemonicCode(seedPhrase).toSeed()
Initializer(requireApplicationContext()) { runBlocking { Initializer.new(requireApplicationContext()) {
it.importWallet(seed, network = ZcashNetwork.fromResources(requireApplicationContext())) runBlocking { it.importWallet(seed, network = ZcashNetwork.fromResources(requireApplicationContext())) }
it.setNetwork(ZcashNetwork.fromResources(requireApplicationContext())) it.setNetwork(ZcashNetwork.fromResources(requireApplicationContext()))
}.let { initializer -> }}.let { initializer ->
synchronizer = Synchronizer(initializer) synchronizer = Synchronizer(initializer)
} }
spendingKey = DerivationTool.deriveSpendingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext())).first() spendingKey = runBlocking { DerivationTool.deriveSpendingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext())).first() }
} }
// //

View File

@ -5,6 +5,7 @@ import androidx.test.core.app.ApplicationProvider
import androidx.test.filters.SmallTest import androidx.test.filters.SmallTest
import cash.z.ecc.android.sdk.tool.WalletBirthdayTool import cash.z.ecc.android.sdk.tool.WalletBirthdayTool
import cash.z.ecc.android.sdk.type.ZcashNetwork import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.runBlocking
import org.json.JSONObject import org.json.JSONObject
import org.junit.Assert.assertEquals import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse import org.junit.Assert.assertFalse
@ -92,9 +93,9 @@ class AssetTest {
private data class JsonFile(val jsonObject: JSONObject, val filename: String) private data class JsonFile(val jsonObject: JSONObject, val filename: String)
companion object { companion object {
fun listAssets(network: ZcashNetwork) = WalletBirthdayTool.listBirthdayDirectoryContents( fun listAssets(network: ZcashNetwork) = runBlocking { WalletBirthdayTool.listBirthdayDirectoryContents(
ApplicationProvider.getApplicationContext<Context>(), ApplicationProvider.getApplicationContext<Context>(),
WalletBirthdayTool.birthdayDirectory(network) WalletBirthdayTool.birthdayDirectory(network))
) }
} }
} }

View File

@ -3,13 +3,14 @@ package cash.z.ecc.android.sdk.ext
import cash.z.ecc.android.sdk.Initializer import cash.z.ecc.android.sdk.Initializer
import cash.z.ecc.android.sdk.type.ZcashNetwork import cash.z.ecc.android.sdk.type.ZcashNetwork
import cash.z.ecc.android.sdk.util.SimpleMnemonics import cash.z.ecc.android.sdk.util.SimpleMnemonics
import kotlinx.coroutines.runBlocking
import okhttp3.OkHttpClient import okhttp3.OkHttpClient
import okhttp3.Request import okhttp3.Request
import org.json.JSONObject import org.json.JSONObject
import ru.gildor.coroutines.okhttp.await import ru.gildor.coroutines.okhttp.await
fun Initializer.Config.seedPhrase(seedPhrase: String, network: ZcashNetwork) { fun Initializer.Config.seedPhrase(seedPhrase: String, network: ZcashNetwork) {
setSeed(SimpleMnemonics().toSeed(seedPhrase.toCharArray()), network) runBlocking { setSeed(SimpleMnemonics().toSeed(seedPhrase.toCharArray()), network) }
} }
object BlockExplorer { object BlockExplorer {

View File

@ -46,7 +46,13 @@ class TestnetIntegrationTest : ScopedTest() {
@Test @Test
fun testLoadBirthday() { fun testLoadBirthday() {
val (height, hash, time, tree) = WalletBirthdayTool.loadNearest(context, synchronizer.network, saplingActivation + 1) val (height, hash, time, tree) = runBlocking {
WalletBirthdayTool.loadNearest(
context,
synchronizer.network,
saplingActivation + 1
)
}
assertEquals(saplingActivation, height) assertEquals(saplingActivation, height)
} }
@ -118,9 +124,11 @@ class TestnetIntegrationTest : ScopedTest() {
val toAddress = "zs1vp7kvlqr4n9gpehztr76lcn6skkss9p8keqs3nv8avkdtjrcctrvmk9a7u494kluv756jeee5k0" val toAddress = "zs1vp7kvlqr4n9gpehztr76lcn6skkss9p8keqs3nv8avkdtjrcctrvmk9a7u494kluv756jeee5k0"
private val context = InstrumentationRegistry.getInstrumentation().context private val context = InstrumentationRegistry.getInstrumentation().context
private val initializer = Initializer(context) { config -> private val initializer = runBlocking {
config.setNetwork(ZcashNetwork.Testnet, host) Initializer.new(context) { config ->
config.importWallet(seed, birthdayHeight, ZcashNetwork.Testnet) config.setNetwork(ZcashNetwork.Testnet, host)
runBlocking { config.importWallet(seed, birthdayHeight, ZcashNetwork.Testnet) }
}
} }
private lateinit var synchronizer: Synchronizer private lateinit var synchronizer: Synchronizer

View File

@ -3,6 +3,7 @@ package cash.z.ecc.android.sdk.jni
import cash.z.ecc.android.sdk.annotation.MaintainedTest import cash.z.ecc.android.sdk.annotation.MaintainedTest
import cash.z.ecc.android.sdk.annotation.TestPurpose import cash.z.ecc.android.sdk.annotation.TestPurpose
import cash.z.ecc.android.sdk.type.ZcashNetwork import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.runBlocking
import org.junit.Assert.assertEquals import org.junit.Assert.assertEquals
import org.junit.Test import org.junit.Test
import org.junit.runner.RunWith import org.junit.runner.RunWith
@ -43,8 +44,8 @@ class BranchIdTest(
// is an abnormal use of the SDK because this really should run at the rust level // is an abnormal use of the SDK because this really should run at the rust level
// However, due to quirks on certain devices, we created this test at the Android level, // However, due to quirks on certain devices, we created this test at the Android level,
// as a sanity check // as a sanity check
val testnetBackend = RustBackend.init("", "", "", ZcashNetwork.Testnet) val testnetBackend = runBlocking { RustBackend.init("", "", "", ZcashNetwork.Testnet) }
val mainnetBackend = RustBackend.init("", "", "", ZcashNetwork.Mainnet) val mainnetBackend = runBlocking { RustBackend.init("", "", "", ZcashNetwork.Mainnet) }
return listOf( return listOf(
// Mainnet Cases // Mainnet Cases
arrayOf("Sapling", 419_200, 1991772603L, "76b809bb", mainnetBackend), arrayOf("Sapling", 419_200, 1991772603L, "76b809bb", mainnetBackend),

View File

@ -9,6 +9,7 @@ import cash.z.ecc.android.sdk.internal.TroubleshootingTwig
import cash.z.ecc.android.sdk.internal.Twig import cash.z.ecc.android.sdk.internal.Twig
import cash.z.ecc.android.sdk.tool.DerivationTool import cash.z.ecc.android.sdk.tool.DerivationTool
import cash.z.ecc.android.sdk.type.ZcashNetwork import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.runBlocking
import org.junit.Assert.assertEquals import org.junit.Assert.assertEquals
import org.junit.BeforeClass import org.junit.BeforeClass
import org.junit.Test import org.junit.Test
@ -21,23 +22,23 @@ import org.junit.runners.Parameterized
class TransparentTest(val expected: Expected, val network: ZcashNetwork) { class TransparentTest(val expected: Expected, val network: ZcashNetwork) {
@Test @Test
fun deriveTransparentSecretKeyTest() { fun deriveTransparentSecretKeyTest() = runBlocking {
assertEquals(expected.tskCompressed, DerivationTool.deriveTransparentSecretKey(SEED, network = network)) assertEquals(expected.tskCompressed, DerivationTool.deriveTransparentSecretKey(SEED, network = network))
} }
@Test @Test
fun deriveTransparentAddressTest() { fun deriveTransparentAddressTest() = runBlocking {
assertEquals(expected.tAddr, DerivationTool.deriveTransparentAddress(SEED, network = network)) assertEquals(expected.tAddr, DerivationTool.deriveTransparentAddress(SEED, network = network))
} }
@Test @Test
fun deriveTransparentAddressFromSecretKeyTest() { fun deriveTransparentAddressFromSecretKeyTest() = runBlocking {
val pk = DerivationTool.deriveTransparentSecretKey(SEED, network = network) val pk = DerivationTool.deriveTransparentSecretKey(SEED, network = network)
assertEquals(expected.tAddr, DerivationTool.deriveTransparentAddressFromPrivateKey(pk, network = network)) assertEquals(expected.tAddr, DerivationTool.deriveTransparentAddressFromPrivateKey(pk, network = network))
} }
@Test @Test
fun deriveUnifiedViewingKeysFromSeedTest() { fun deriveUnifiedViewingKeysFromSeedTest() = runBlocking {
val uvks = DerivationTool.deriveUnifiedViewingKeys(SEED, network = network) val uvks = DerivationTool.deriveUnifiedViewingKeys(SEED, network = network)
assertEquals(1, uvks.size) assertEquals(1, uvks.size)
val uvk = uvks.first() val uvk = uvks.first()

View File

@ -4,7 +4,8 @@ import android.content.Context
import androidx.test.core.app.ApplicationProvider import androidx.test.core.app.ApplicationProvider
import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.filters.SmallTest import androidx.test.filters.SmallTest
import cash.z.ecc.android.sdk.tool.WalletBirthdayTool.Companion.IS_FALLBACK_ON_FAILURE import cash.z.ecc.android.sdk.tool.WalletBirthdayTool.IS_FALLBACK_ON_FAILURE
import kotlinx.coroutines.runBlocking
import org.junit.Assert.assertEquals import org.junit.Assert.assertEquals
import org.junit.Test import org.junit.Test
import org.junit.runner.RunWith import org.junit.runner.RunWith
@ -25,11 +26,13 @@ class WalletBirthdayToolTest {
val directory = "saplingtree/goodnet" val directory = "saplingtree/goodnet"
val context = ApplicationProvider.getApplicationContext<Context>() val context = ApplicationProvider.getApplicationContext<Context>()
val birthday = WalletBirthdayTool.getFirstValidWalletBirthday( val birthday = runBlocking {
context, WalletBirthdayTool.getFirstValidWalletBirthday(
directory, context,
listOf("1300000.json", "1290000.json") directory,
) listOf("1300000.json", "1290000.json")
)
}
assertEquals(1300000, birthday.height) assertEquals(1300000, birthday.height)
} }
@ -42,11 +45,13 @@ class WalletBirthdayToolTest {
val directory = "saplingtree/badnet" val directory = "saplingtree/badnet"
val context = ApplicationProvider.getApplicationContext<Context>() val context = ApplicationProvider.getApplicationContext<Context>()
val birthday = WalletBirthdayTool.getFirstValidWalletBirthday( val birthday = runBlocking {
context, WalletBirthdayTool.getFirstValidWalletBirthday(
directory, context,
listOf("1300000.json", "1290000.json") directory,
) listOf("1300000.json", "1290000.json")
)
}
assertEquals(1290000, birthday.height) assertEquals(1290000, birthday.height)
} }
} }

View File

@ -5,6 +5,7 @@ import cash.z.ecc.android.sdk.Initializer
import cash.z.ecc.android.sdk.Synchronizer import cash.z.ecc.android.sdk.Synchronizer
import cash.z.ecc.android.sdk.internal.TroubleshootingTwig import cash.z.ecc.android.sdk.internal.TroubleshootingTwig
import cash.z.ecc.android.sdk.internal.Twig import cash.z.ecc.android.sdk.internal.Twig
import cash.z.ecc.android.sdk.internal.ext.deleteSuspend
import cash.z.ecc.android.sdk.internal.twig import cash.z.ecc.android.sdk.internal.twig
import cash.z.ecc.android.sdk.tool.WalletBirthdayTool import cash.z.ecc.android.sdk.tool.WalletBirthdayTool
import cash.z.ecc.android.sdk.type.WalletBirthday import cash.z.ecc.android.sdk.type.WalletBirthday
@ -52,7 +53,7 @@ class BalancePrinterUtil {
fun setup() { fun setup() {
Twig.plant(TroubleshootingTwig()) Twig.plant(TroubleshootingTwig())
cacheBlocks() cacheBlocks()
birthday = WalletBirthdayTool.loadNearest(context, network, birthdayHeight) birthday = runBlocking { WalletBirthdayTool.loadNearest(context, network, birthdayHeight) }
} }
private fun cacheBlocks() = runBlocking { private fun cacheBlocks() = runBlocking {
@ -66,8 +67,8 @@ class BalancePrinterUtil {
// assertEquals(-1, error) // assertEquals(-1, error)
} }
private fun deleteDb(dbName: String) { private suspend fun deleteDb(dbName: String) {
context.getDatabasePath(dbName).absoluteFile.delete() context.getDatabasePath(dbName).absoluteFile.deleteSuspend()
} }
@Test @Test
@ -79,8 +80,8 @@ class BalancePrinterUtil {
mnemonics.toSeed(seedPhrase.toCharArray()) mnemonics.toSeed(seedPhrase.toCharArray())
}.collect { seed -> }.collect { seed ->
// TODO: clear the dataDb but leave the cacheDb // TODO: clear the dataDb but leave the cacheDb
val initializer = Initializer(context) { config -> val initializer = Initializer.new(context) { config ->
config.importWallet(seed, birthdayHeight, network) runBlocking { config.importWallet(seed, birthdayHeight, network) }
config.setNetwork(network) config.setNetwork(network)
config.alias = alias config.alias = alias
} }

View File

@ -64,7 +64,13 @@ class DataDbScannerUtil {
@Test @Test
@Ignore("This test is broken") @Ignore("This test is broken")
fun scanExistingDb() { fun scanExistingDb() {
synchronizer = Synchronizer(Initializer(context) { it.setBirthdayHeight(birthdayHeight) }) synchronizer = Synchronizer(runBlocking {
Initializer.new(context) {
it.setBirthdayHeight(
birthdayHeight
)
}
})
println("sync!") println("sync!")
synchronizer.start() synchronizer.start()

View File

@ -23,6 +23,7 @@ import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.takeWhile import kotlinx.coroutines.flow.takeWhile
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.newFixedThreadPoolContext import kotlinx.coroutines.newFixedThreadPoolContext
import kotlinx.coroutines.runBlocking
import java.util.concurrent.TimeoutException import java.util.concurrent.TimeoutException
/** /**
@ -51,19 +52,29 @@ class TestWallet(
val walletScope = CoroutineScope( val walletScope = CoroutineScope(
SupervisorJob() + newFixedThreadPoolContext(3, this.javaClass.simpleName) SupervisorJob() + newFixedThreadPoolContext(3, this.javaClass.simpleName)
) )
// Although runBlocking isn't great, this usage is OK because this is only used within the
// automated tests
private val context = InstrumentationRegistry.getInstrumentation().context private val context = InstrumentationRegistry.getInstrumentation().context
private val seed: ByteArray = Mnemonics.MnemonicCode(seedPhrase).toSeed() private val seed: ByteArray = Mnemonics.MnemonicCode(seedPhrase).toSeed()
private val shieldedSpendingKey = DerivationTool.deriveSpendingKeys(seed, network = network)[0] private val shieldedSpendingKey =
private val transparentSecretKey = DerivationTool.deriveTransparentSecretKey(seed, network = network) runBlocking { DerivationTool.deriveSpendingKeys(seed, network = network)[0] }
val initializer = Initializer(context) { config -> private val transparentSecretKey =
config.importWallet(seed, startHeight, network, host, alias = alias) runBlocking { DerivationTool.deriveTransparentSecretKey(seed, network = network) }
val initializer = runBlocking {
Initializer.new(context) { config ->
runBlocking { config.importWallet(seed, startHeight, network, host, alias = alias) }
}
} }
val synchronizer: SdkSynchronizer = Synchronizer(initializer) as SdkSynchronizer val synchronizer: SdkSynchronizer = Synchronizer(initializer) as SdkSynchronizer
val service = (synchronizer.processor.downloader.lightWalletService as LightWalletGrpcService) val service = (synchronizer.processor.downloader.lightWalletService as LightWalletGrpcService)
val available get() = synchronizer.saplingBalances.value.availableZatoshi val available get() = synchronizer.saplingBalances.value.availableZatoshi
val shieldedAddress = DerivationTool.deriveShieldedAddress(seed, network = network) val shieldedAddress =
val transparentAddress = DerivationTool.deriveTransparentAddress(seed, network = network) runBlocking { DerivationTool.deriveShieldedAddress(seed, network = network) }
val transparentAddress =
runBlocking { DerivationTool.deriveTransparentAddress(seed, network = network) }
val birthdayHeight get() = synchronizer.latestBirthdayHeight val birthdayHeight get() = synchronizer.latestBirthdayHeight
val networkName get() = synchronizer.network.networkName val networkName get() = synchronizer.network.networkName
val connectionInfo get() = service.connectionInfo.toString() val connectionInfo get() = service.connectionInfo.toString()

View File

@ -4,95 +4,37 @@ import android.content.Context
import cash.z.ecc.android.sdk.exception.InitializerException import cash.z.ecc.android.sdk.exception.InitializerException
import cash.z.ecc.android.sdk.ext.ZcashSdk import cash.z.ecc.android.sdk.ext.ZcashSdk
import cash.z.ecc.android.sdk.internal.twig import cash.z.ecc.android.sdk.internal.twig
import cash.z.ecc.android.sdk.internal.SdkDispatchers
import cash.z.ecc.android.sdk.internal.ext.getCacheDirSuspend
import cash.z.ecc.android.sdk.internal.ext.getDatabasePathSuspend
import cash.z.ecc.android.sdk.jni.RustBackend import cash.z.ecc.android.sdk.jni.RustBackend
import cash.z.ecc.android.sdk.tool.DerivationTool import cash.z.ecc.android.sdk.tool.DerivationTool
import cash.z.ecc.android.sdk.tool.WalletBirthdayTool import cash.z.ecc.android.sdk.tool.WalletBirthdayTool
import cash.z.ecc.android.sdk.type.UnifiedViewingKey import cash.z.ecc.android.sdk.type.UnifiedViewingKey
import cash.z.ecc.android.sdk.type.WalletBirthday import cash.z.ecc.android.sdk.type.WalletBirthday
import cash.z.ecc.android.sdk.type.ZcashNetwork import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.io.File import java.io.File
/** /**
* Simplified Initializer focused on starting from a ViewingKey. * Simplified Initializer focused on starting from a ViewingKey.
*/ */
class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Throwable?) -> Boolean)? = null, config: Config) { class Initializer private constructor(
val context = appContext.applicationContext val context: Context,
val rustBackend: RustBackend val rustBackend: RustBackend,
val network: ZcashNetwork val network: ZcashNetwork,
val alias: String val alias: String,
val host: String val host: String,
val port: Int val port: Int,
val viewingKeys: List<UnifiedViewingKey> val viewingKeys: List<UnifiedViewingKey>,
val overwriteVks: Boolean val overwriteVks: Boolean,
val birthday: WalletBirthday val birthday: WalletBirthday
) {
/** suspend fun erase() = erase(context, network, alias)
* A callback to invoke whenever an uncaught error is encountered. By definition, the return
* value of the function is ignored because this error is unrecoverable. The only reason the
* function has a return value is so that all error handlers work with the same signature which
* allows one function to handle all errors in simple apps.
*/
var onCriticalErrorHandler: ((Throwable?) -> Boolean)? = onCriticalErrorHandler
init { class Config private constructor(
try {
config.validate()
network = config.network
val heightToUse = config.birthdayHeight
?: (if (config.defaultToOldestHeight == true) network.saplingActivationHeight else null)
val loadedBirthday = WalletBirthdayTool.loadNearest(context, network, heightToUse)
birthday = loadedBirthday
viewingKeys = config.viewingKeys
overwriteVks = config.overwriteVks
alias = config.alias
host = config.host
port = config.port
rustBackend = initRustBackend(network, birthday)
} catch (t: Throwable) {
onCriticalError(t)
throw t
}
}
constructor(appContext: Context, config: Config) : this(appContext, null, config)
constructor(appContext: Context, onCriticalErrorHandler: ((Throwable?) -> Boolean)? = null, block: (Config) -> Unit) : this(appContext, onCriticalErrorHandler, Config(block))
fun erase() = erase(context, network, alias)
private fun initRustBackend(network: ZcashNetwork, birthday: WalletBirthday): RustBackend {
return RustBackend.init(
cacheDbPath(context, network, alias),
dataDbPath(context, network, alias),
"${context.cacheDir.absolutePath}/params",
network,
birthday.height
)
}
private fun onCriticalError(error: Throwable) {
twig("********")
twig("******** INITIALIZER ERROR: $error")
if (error.cause != null) twig("******** caused by ${error.cause}")
if (error.cause?.cause != null) twig("******** caused by ${error.cause?.cause}")
twig("********")
twig(error)
if (onCriticalErrorHandler == null) {
twig(
"WARNING: a critical error occurred on the Initializer but no callback is " +
"registered to be notified of critical errors! THIS IS PROBABLY A MISTAKE. To " +
"respond to these errors (perhaps to update the UI or alert the user) set " +
"initializer.onCriticalErrorHandler to a non-null value or use the secondary " +
"constructor: Initializer(context, handler) { ... }. Note that the synchronizer " +
"and initializer BOTH have error handlers and since the initializer exists " +
"before the synchronizer, it needs its error handler set separately."
)
}
onCriticalErrorHandler?.invoke(error)
}
class Config private constructor (
val viewingKeys: MutableList<UnifiedViewingKey> = mutableListOf(), val viewingKeys: MutableList<UnifiedViewingKey> = mutableListOf(),
var alias: String = ZcashSdk.DEFAULT_ALIAS, var alias: String = ZcashSdk.DEFAULT_ALIAS,
) { ) {
@ -177,7 +119,10 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
* is not currently well supported. Consider it an alpha-preview feature that might work but * is not currently well supported. Consider it an alpha-preview feature that might work but
* probably has serious bugs. * probably has serious bugs.
*/ */
fun setViewingKeys(vararg unifiedViewingKeys: UnifiedViewingKey, overwrite: Boolean = false): Config = apply { fun setViewingKeys(
vararg unifiedViewingKeys: UnifiedViewingKey,
overwrite: Boolean = false
): Config = apply {
overwriteVks = overwrite overwriteVks = overwrite
viewingKeys.apply { viewingKeys.apply {
clear() clear()
@ -225,7 +170,7 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
/** /**
* Import a wallet using the first viewing key derived from the given seed. * Import a wallet using the first viewing key derived from the given seed.
*/ */
fun importWallet( suspend fun importWallet(
seed: ByteArray, seed: ByteArray,
birthdayHeight: Int? = null, birthdayHeight: Int? = null,
network: ZcashNetwork, network: ZcashNetwork,
@ -262,7 +207,7 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
/** /**
* Create a new wallet using the first viewing key derived from the given seed. * Create a new wallet using the first viewing key derived from the given seed.
*/ */
fun newWallet( suspend fun newWallet(
seed: ByteArray, seed: ByteArray,
network: ZcashNetwork, network: ZcashNetwork,
host: String = network.defaultHost, host: String = network.defaultHost,
@ -296,9 +241,20 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
* Convenience method for setting thew viewingKeys from a given seed. This is the same as * Convenience method for setting thew viewingKeys from a given seed. This is the same as
* calling `setViewingKeys` with the keys that match this seed. * calling `setViewingKeys` with the keys that match this seed.
*/ */
fun setSeed(seed: ByteArray, network: ZcashNetwork, numberOfAccounts: Int = 1): Config = apply { suspend fun setSeed(
setViewingKeys(*DerivationTool.deriveUnifiedViewingKeys(seed, network, numberOfAccounts)) seed: ByteArray,
} network: ZcashNetwork,
numberOfAccounts: Int = 1
): Config =
apply {
setViewingKeys(
*DerivationTool.deriveUnifiedViewingKeys(
seed,
network,
numberOfAccounts
)
)
}
/** /**
* Sets the network from a network id, throwing an exception if the id is not recognized. * Sets the network from a network id, throwing an exception if the id is not recognized.
@ -338,16 +294,89 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
private fun validateViewingKeys() { private fun validateViewingKeys() {
require(viewingKeys.isNotEmpty()) { require(viewingKeys.isNotEmpty()) {
"Unified Viewing keys are required. Ensure that the unified viewing keys or seed" + "Unified Viewing keys are required. Ensure that the unified viewing keys or seed" +
" have been set on this Initializer." " have been set on this Initializer."
} }
viewingKeys.forEach { viewingKeys.forEach {
DerivationTool.validateUnifiedViewingKey(it) DerivationTool.validateUnifiedViewingKey(it)
} }
} }
} }
companion object : SdkSynchronizer.Erasable { companion object : SdkSynchronizer.Erasable {
suspend fun new(appContext: Context, config: Config) = new(appContext, null, config)
suspend fun new(
appContext: Context,
onCriticalErrorHandler: ((Throwable?) -> Boolean)? = null,
block: (Config) -> Unit
) = new(appContext, onCriticalErrorHandler, Config(block))
suspend fun new(
context: Context,
onCriticalErrorHandler: ((Throwable?) -> Boolean)?,
config: Config
): Initializer {
config.validate()
val heightToUse = config.birthdayHeight
?: (if (config.defaultToOldestHeight == true) config.network.saplingActivationHeight else null)
val loadedBirthday =
WalletBirthdayTool.loadNearest(context, config.network, heightToUse)
val rustBackend = initRustBackend(context, config.network, config.alias, loadedBirthday)
return Initializer(
context.applicationContext,
rustBackend,
config.network,
config.alias,
config.host,
config.port,
config.viewingKeys,
config.overwriteVks,
loadedBirthday
)
}
private fun onCriticalError(onCriticalErrorHandler: ((Throwable?) -> Boolean)?, error: Throwable) {
twig("********")
twig("******** INITIALIZER ERROR: $error")
if (error.cause != null) twig("******** caused by ${error.cause}")
if (error.cause?.cause != null) twig("******** caused by ${error.cause?.cause}")
twig("********")
twig(error)
if (onCriticalErrorHandler == null) {
twig(
"WARNING: a critical error occurred on the Initializer but no callback is " +
"registered to be notified of critical errors! THIS IS PROBABLY A MISTAKE. To " +
"respond to these errors (perhaps to update the UI or alert the user) set " +
"initializer.onCriticalErrorHandler to a non-null value or use the secondary " +
"constructor: Initializer(context, handler) { ... }. Note that the synchronizer " +
"and initializer BOTH have error handlers and since the initializer exists " +
"before the synchronizer, it needs its error handler set separately."
)
}
onCriticalErrorHandler?.invoke(error)
}
private suspend fun initRustBackend(
context: Context,
network: ZcashNetwork,
alias: String,
birthday: WalletBirthday
): RustBackend {
return RustBackend.init(
cacheDbPath(context, network, alias),
dataDbPath(context, network, alias),
File(context.getCacheDirSuspend(), "params").absolutePath,
network,
birthday.height
)
}
/** /**
* Delete the databases associated with this wallet. This removes all compact blocks and * Delete the databases associated with this wallet. This removes all compact blocks and
* data derived from those blocks. For most wallets, this should not result in a loss of * data derived from those blocks. For most wallets, this should not result in a loss of
@ -362,7 +391,11 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
* @return true when one of the associated files was found. False most likely indicates * @return true when one of the associated files was found. False most likely indicates
* that the wrong alias was provided. * that the wrong alias was provided.
*/ */
override fun erase(appContext: Context, network: ZcashNetwork, alias: String): Boolean { override suspend fun erase(
appContext: Context,
network: ZcashNetwork,
alias: String
): Boolean {
val cacheDeleted = deleteDb(cacheDbPath(appContext, network, alias)) val cacheDeleted = deleteDb(cacheDbPath(appContext, network, alias))
val dataDeleted = deleteDb(dataDbPath(appContext, network, alias)) val dataDeleted = deleteDb(dataDbPath(appContext, network, alias))
return dataDeleted || cacheDeleted return dataDeleted || cacheDeleted
@ -379,7 +412,11 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
* @param network the network associated with the data in the cache database. * @param network the network associated with the data in the cache database.
* @param alias the alias to convert into a database path * @param alias the alias to convert into a database path
*/ */
internal fun cacheDbPath(appContext: Context, network: ZcashNetwork, alias: String): String = private suspend fun cacheDbPath(
appContext: Context,
network: ZcashNetwork,
alias: String
): String =
aliasToPath(appContext, network, alias, ZcashSdk.DB_CACHE_NAME) aliasToPath(appContext, network, alias, ZcashSdk.DB_CACHE_NAME)
/** /**
@ -388,12 +425,21 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
* * @param network the network associated with the data in the database. * * @param network the network associated with the data in the database.
* @param alias the alias to convert into a database path * @param alias the alias to convert into a database path
*/ */
internal fun dataDbPath(appContext: Context, network: ZcashNetwork, alias: String): String = private suspend fun dataDbPath(
appContext: Context,
network: ZcashNetwork,
alias: String
): String =
aliasToPath(appContext, network, alias, ZcashSdk.DB_DATA_NAME) aliasToPath(appContext, network, alias, ZcashSdk.DB_DATA_NAME)
private fun aliasToPath(appContext: Context, network: ZcashNetwork, alias: String, dbFileName: String): String { private suspend fun aliasToPath(
appContext: Context,
network: ZcashNetwork,
alias: String,
dbFileName: String
): String {
val parentDir: String = val parentDir: String =
appContext.getDatabasePath("unused.db").parentFile?.absolutePath appContext.getDatabasePathSuspend("unused.db").parentFile?.absolutePath
?: throw InitializerException.DatabasePathException ?: throw InitializerException.DatabasePathException
val prefix = if (alias.endsWith('_')) alias else "${alias}_" val prefix = if (alias.endsWith('_')) alias else "${alias}_"
return File(parentDir, "$prefix${network.networkName}_$dbFileName").absolutePath return File(parentDir, "$prefix${network.networkName}_$dbFileName").absolutePath
@ -405,9 +451,10 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
* @param filePath the path of the db to erase. * @param filePath the path of the db to erase.
* @return true when a file exists at the given path and was deleted. * @return true when a file exists at the given path and was deleted.
*/ */
private fun deleteDb(filePath: String): Boolean { private suspend fun deleteDb(filePath: String): Boolean {
// just try the journal file. Doesn't matter if it's not there. // just try the journal file. Doesn't matter if it's not there.
delete("$filePath-journal") delete("$filePath-journal")
return delete(filePath) return delete(filePath)
} }
@ -417,14 +464,16 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
* @param filePath the path of the file to erase. * @param filePath the path of the file to erase.
* @return true when a file exists at the given path and was deleted. * @return true when a file exists at the given path and was deleted.
*/ */
private fun delete(filePath: String): Boolean { private suspend fun delete(filePath: String): Boolean {
return File(filePath).let { return File(filePath).let {
if (it.exists()) { withContext(SdkDispatchers.IO) {
twig("Deleting ${it.name}!") if (it.exists()) {
it.delete() twig("Deleting ${it.name}!")
true it.delete()
} else { true
false } else {
false
}
} }
} }
} }
@ -445,9 +494,9 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
internal fun validateAlias(alias: String) { internal fun validateAlias(alias: String) {
require( require(
alias.length in 1..99 && alias[0].isLetter() && alias.length in 1..99 && alias[0].isLetter() &&
alias.all { it.isLetterOrDigit() || it == '_' } alias.all { it.isLetterOrDigit() || it == '_' }
) { ) {
"ERROR: Invalid alias ($alias). For security, the alias must be shorter than 100 " + "ERROR: Invalid alias ($alias). For security, the alias must be shorter than 100 " +
"characters and only contain letters, digits or underscores and start with a letter." "characters and only contain letters, digits or underscores and start with a letter."
} }
} }

View File

@ -58,7 +58,6 @@ import io.grpc.ManagedChannel
import kotlinx.coroutines.CoroutineExceptionHandler import kotlinx.coroutines.CoroutineExceptionHandler
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Dispatchers.IO
import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.FlowPreview
import kotlinx.coroutines.Job import kotlinx.coroutines.Job
@ -247,7 +246,7 @@ class SdkSynchronizer internal constructor(
override val latestBirthdayHeight: Int get() = processor.birthdayHeight override val latestBirthdayHeight: Int get() = processor.birthdayHeight
override fun prepare(): Synchronizer = apply { override suspend fun prepare(): Synchronizer = apply {
storage.prepare() storage.prepare()
} }
@ -336,15 +335,15 @@ class SdkSynchronizer internal constructor(
// TODO: turn this section into the data access API. For now, just aggregate all the things that we want to do with the underlying data // TODO: turn this section into the data access API. For now, just aggregate all the things that we want to do with the underlying data
fun findBlockHash(height: Int): ByteArray? { suspend fun findBlockHash(height: Int): ByteArray? {
return (storage as? PagedTransactionRepository)?.findBlockHash(height) return (storage as? PagedTransactionRepository)?.findBlockHash(height)
} }
fun findBlockHashAsHex(height: Int): String? { suspend fun findBlockHashAsHex(height: Int): String? {
return findBlockHash(height)?.toHexReversed() return findBlockHash(height)?.toHexReversed()
} }
fun getTransactionCount(): Int { suspend fun getTransactionCount(): Int {
return (storage as? PagedTransactionRepository)?.getTransactionCount() ?: 0 return (storage as? PagedTransactionRepository)?.getTransactionCount() ?: 0
} }
@ -530,7 +529,7 @@ class SdkSynchronizer internal constructor(
} }
} }
private suspend fun refreshPendingTransactions() = withContext(IO) { private suspend fun refreshPendingTransactions() = withContext(Dispatchers.IO) {
twig("[cleanup] beginning to refresh and clean up pending transactions") twig("[cleanup] beginning to refresh and clean up pending transactions")
// TODO: this would be the place to clear out any stale pending transactions. Remove filter // TODO: this would be the place to clear out any stale pending transactions. Remove filter
// logic and then delete any pending transaction with sufficient confirmations (all in one // logic and then delete any pending transaction with sufficient confirmations (all in one
@ -737,7 +736,7 @@ class SdkSynchronizer internal constructor(
* *
* @return true when content was found for the given alias. False otherwise. * @return true when content was found for the given alias. False otherwise.
*/ */
fun erase(appContext: Context, network: ZcashNetwork, alias: String = ZcashSdk.DEFAULT_ALIAS): Boolean suspend fun erase(appContext: Context, network: ZcashNetwork, alias: String = ZcashSdk.DEFAULT_ALIAS): Boolean
} }
} }

View File

@ -35,7 +35,7 @@ interface Synchronizer {
* where setup and maintenance can occur for various Synchronizers. One that uses a database * where setup and maintenance can occur for various Synchronizers. One that uses a database
* would take this opportunity to do data migrations or key migrations. * would take this opportunity to do data migrations or key migrations.
*/ */
fun prepare(): Synchronizer suspend fun prepare(): Synchronizer
/** /**
* Starts this synchronizer within the given scope. * Starts this synchronizer within the given scope.

View File

@ -526,7 +526,7 @@ class CompactBlockProcessor(
* @return [ERROR_CODE_NONE] when there is no problem. Otherwise, return the lowest height where an error was * @return [ERROR_CODE_NONE] when there is no problem. Otherwise, return the lowest height where an error was
* found. In other words, validation starts at the back of the chain and works toward the tip. * found. In other words, validation starts at the back of the chain and works toward the tip.
*/ */
private fun validateNewBlocks(range: IntRange?): Int { private suspend fun validateNewBlocks(range: IntRange?): Int {
if (range?.isEmpty() != false) { if (range?.isEmpty() != false) {
twig("no blocks to validate: $range") twig("no blocks to validate: $range")
return ERROR_CODE_NONE return ERROR_CODE_NONE

View File

@ -0,0 +1,14 @@
package cash.z.ecc.android.sdk.internal
import kotlinx.coroutines.asCoroutineDispatcher
import java.util.concurrent.Executors
internal object SdkDispatchers {
/*
* Based on internal discussion, keep the SDK internals confined to a single IO thread.
*
* We don't expect things to break, but we don't have the WAL enabled for SQLite so this
* is a simple solution.
*/
val IO = Executors.newSingleThreadExecutor().asCoroutineDispatcher()
}

View File

@ -6,8 +6,9 @@ import androidx.room.RoomDatabase
import cash.z.ecc.android.sdk.internal.db.CompactBlockDao import cash.z.ecc.android.sdk.internal.db.CompactBlockDao
import cash.z.ecc.android.sdk.internal.db.CompactBlockDb import cash.z.ecc.android.sdk.internal.db.CompactBlockDb
import cash.z.ecc.android.sdk.db.entity.CompactBlockEntity import cash.z.ecc.android.sdk.db.entity.CompactBlockEntity
import cash.z.ecc.android.sdk.internal.SdkDispatchers
import cash.z.wallet.sdk.rpc.CompactFormats import cash.z.wallet.sdk.rpc.CompactFormats
import kotlinx.coroutines.Dispatchers.IO import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
/** /**
@ -38,7 +39,7 @@ class CompactBlockDbStore(
.build() .build()
} }
override suspend fun getLatestHeight(): Int = withContext(IO) { override suspend fun getLatestHeight(): Int = withContext(SdkDispatchers.IO) {
Math.max(0, cacheDao.latestBlockHeight()) Math.max(0, cacheDao.latestBlockHeight())
} }
@ -46,15 +47,17 @@ class CompactBlockDbStore(
return cacheDao.findCompactBlock(height)?.let { CompactFormats.CompactBlock.parseFrom(it) } return cacheDao.findCompactBlock(height)?.let { CompactFormats.CompactBlock.parseFrom(it) }
} }
override suspend fun write(result: List<CompactFormats.CompactBlock>) = withContext(IO) { override suspend fun write(result: List<CompactFormats.CompactBlock>) = withContext(SdkDispatchers.IO) {
cacheDao.insert(result.map { CompactBlockEntity(it.height.toInt(), it.toByteArray()) }) cacheDao.insert(result.map { CompactBlockEntity(it.height.toInt(), it.toByteArray()) })
} }
override suspend fun rewindTo(height: Int) = withContext(IO) { override suspend fun rewindTo(height: Int) = withContext(SdkDispatchers.IO) {
cacheDao.rewindTo(height) cacheDao.rewindTo(height)
} }
override fun close() { override suspend fun close() {
cacheDb.close() withContext(SdkDispatchers.IO) {
cacheDb.close()
}
} }
} }

View File

@ -7,6 +7,7 @@ import cash.z.ecc.android.sdk.internal.service.LightWalletService
import cash.z.wallet.sdk.rpc.Service import cash.z.wallet.sdk.rpc.Service
import io.grpc.StatusRuntimeException import io.grpc.StatusRuntimeException
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Dispatchers.IO import kotlinx.coroutines.Dispatchers.IO
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
@ -122,8 +123,10 @@ open class CompactBlockDownloader private constructor(val compactBlockStore: Com
/** /**
* Stop this downloader and cleanup any resources being used. * Stop this downloader and cleanup any resources being used.
*/ */
fun stop() { suspend fun stop() {
lightWalletService.shutdown() withContext(Dispatchers.IO) {
lightWalletService.shutdown()
}
compactBlockStore.close() compactBlockStore.close()
} }

View File

@ -37,5 +37,5 @@ interface CompactBlockStore {
/** /**
* Close any connections to the block store. * Close any connections to the block store.
*/ */
fun close() suspend fun close()
} }

View File

@ -0,0 +1,11 @@
package cash.z.ecc.android.sdk.internal.ext
import android.content.Context
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
suspend fun Context.getDatabasePathSuspend(fileName: String) =
withContext(Dispatchers.IO) { getDatabasePath(fileName) }
suspend fun Context.getCacheDirSuspend() =
withContext(Dispatchers.IO) { cacheDir }

View File

@ -0,0 +1,7 @@
package cash.z.ecc.android.sdk.internal.ext
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.io.File
suspend fun File.deleteSuspend() = withContext(Dispatchers.IO) { delete() }

View File

@ -16,10 +16,12 @@ import cash.z.ecc.android.sdk.internal.ext.android.toFlowPagedList
import cash.z.ecc.android.sdk.internal.ext.android.toRefreshable import cash.z.ecc.android.sdk.internal.ext.android.toRefreshable
import cash.z.ecc.android.sdk.internal.ext.tryWarn import cash.z.ecc.android.sdk.internal.ext.tryWarn
import cash.z.ecc.android.sdk.internal.twig import cash.z.ecc.android.sdk.internal.twig
import cash.z.ecc.android.sdk.internal.SdkDispatchers
import cash.z.ecc.android.sdk.jni.RustBackend import cash.z.ecc.android.sdk.jni.RustBackend
import cash.z.ecc.android.sdk.type.UnifiedAddressAccount import cash.z.ecc.android.sdk.type.UnifiedAddressAccount
import cash.z.ecc.android.sdk.type.UnifiedViewingKey import cash.z.ecc.android.sdk.type.UnifiedViewingKey
import cash.z.ecc.android.sdk.type.WalletBirthday import cash.z.ecc.android.sdk.type.WalletBirthday
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Dispatchers.IO import kotlinx.coroutines.Dispatchers.IO
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.emitAll import kotlinx.coroutines.flow.emitAll
@ -95,7 +97,7 @@ class PagedTransactionRepository(
override suspend fun getAccountCount(): Int = lazy.accounts.count() override suspend fun getAccountCount(): Int = lazy.accounts.count()
override fun prepare() { override suspend fun prepare() {
if (lazy.isPrepared.get()) { if (lazy.isPrepared.get()) {
twig("Warning: skipped the preparation step because we're already prepared!") twig("Warning: skipped the preparation step because we're already prepared!")
} else { } else {
@ -112,7 +114,7 @@ class PagedTransactionRepository(
* side because Rust was intended to own the "dataDb" and Kotlin just reads from it. Since then, * side because Rust was intended to own the "dataDb" and Kotlin just reads from it. Since then,
* it has been more clear that Kotlin should own the data and just let Rust use it. * it has been more clear that Kotlin should own the data and just let Rust use it.
*/ */
private fun initMissingDatabases() { private suspend fun initMissingDatabases() {
maybeCreateDataDb() maybeCreateDataDb()
maybeInitBlocksTable(birthday) maybeInitBlocksTable(birthday)
maybeInitAccountsTable(viewingKeys) maybeInitAccountsTable(viewingKeys)
@ -121,7 +123,7 @@ class PagedTransactionRepository(
/** /**
* Create the dataDb and its table, if it doesn't exist. * Create the dataDb and its table, if it doesn't exist.
*/ */
private fun maybeCreateDataDb() { private suspend fun maybeCreateDataDb() {
tryWarn("Warning: did not create dataDb. It probably already exists.") { tryWarn("Warning: did not create dataDb. It probably already exists.") {
rustBackend.initDataDb() rustBackend.initDataDb()
twig("Initialized wallet for first run file: ${rustBackend.pathDataDb}") twig("Initialized wallet for first run file: ${rustBackend.pathDataDb}")
@ -131,7 +133,7 @@ class PagedTransactionRepository(
/** /**
* Initialize the blocks table with the given birthday, if needed. * Initialize the blocks table with the given birthday, if needed.
*/ */
private fun maybeInitBlocksTable(birthday: WalletBirthday) { private suspend fun maybeInitBlocksTable(birthday: WalletBirthday) {
// TODO: consider converting these to typed exceptions in the welding layer // TODO: consider converting these to typed exceptions in the welding layer
tryWarn( tryWarn(
"Warning: did not initialize the blocks table. It probably was already initialized.", "Warning: did not initialize the blocks table. It probably was already initialized.",
@ -151,7 +153,7 @@ class PagedTransactionRepository(
/** /**
* Initialize the accounts table with the given viewing keys. * Initialize the accounts table with the given viewing keys.
*/ */
private fun maybeInitAccountsTable(viewingKeys: List<UnifiedViewingKey>) { private suspend fun maybeInitAccountsTable(viewingKeys: List<UnifiedViewingKey>) {
// TODO: consider converting these to typed exceptions in the welding layer // TODO: consider converting these to typed exceptions in the welding layer
tryWarn( tryWarn(
"Warning: did not initialize the accounts table. It probably was already initialized.", "Warning: did not initialize the accounts table. It probably was already initialized.",
@ -181,7 +183,7 @@ class PagedTransactionRepository(
} }
} }
private fun applyKeyMigrations() { private suspend fun applyKeyMigrations() {
if (overwriteVks) { if (overwriteVks) {
twig("applying key migrations . . .") twig("applying key migrations . . .")
maybeInitAccountsTable(viewingKeys) maybeInitAccountsTable(viewingKeys)
@ -191,8 +193,10 @@ class PagedTransactionRepository(
/** /**
* Close the underlying database. * Close the underlying database.
*/ */
fun close() { suspend fun close() {
lazy.db?.close() withContext(Dispatchers.IO) {
lazy.db?.close()
}
} }
// TODO: begin converting these into Data Access API. For now, just collect the desired operations and iterate/refactor, later // TODO: begin converting these into Data Access API. For now, just collect the desired operations and iterate/refactor, later

View File

@ -87,7 +87,7 @@ interface TransactionRepository {
suspend fun getAccountCount(): Int suspend fun getAccountCount(): Int
fun prepare() suspend fun prepare()
// //
// Transactions // Transactions

View File

@ -8,7 +8,8 @@ import cash.z.ecc.android.sdk.internal.twigTask
import cash.z.ecc.android.sdk.jni.RustBackend import cash.z.ecc.android.sdk.jni.RustBackend
import cash.z.ecc.android.sdk.jni.RustBackendWelding import cash.z.ecc.android.sdk.jni.RustBackendWelding
import cash.z.ecc.android.sdk.internal.SaplingParamTool import cash.z.ecc.android.sdk.internal.SaplingParamTool
import kotlinx.coroutines.Dispatchers.IO import cash.z.ecc.android.sdk.internal.SdkDispatchers
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
/** /**
@ -44,7 +45,7 @@ class WalletTransactionEncoder(
toAddress: String, toAddress: String,
memo: ByteArray?, memo: ByteArray?,
fromAccountIndex: Int fromAccountIndex: Int
): EncodedTransaction = withContext(IO) { ): EncodedTransaction = withContext(SdkDispatchers.IO) {
val transactionId = createSpend(spendingKey, zatoshi, toAddress, memo) val transactionId = createSpend(spendingKey, zatoshi, toAddress, memo)
repository.findEncodedTransactionById(transactionId) repository.findEncodedTransactionById(transactionId)
?: throw TransactionEncoderException.TransactionNotFoundException(transactionId) ?: throw TransactionEncoderException.TransactionNotFoundException(transactionId)
@ -54,7 +55,7 @@ class WalletTransactionEncoder(
spendingKey: String, spendingKey: String,
transparentSecretKey: String, transparentSecretKey: String,
memo: ByteArray? memo: ByteArray?
): EncodedTransaction = withContext(IO) { ): EncodedTransaction = withContext(SdkDispatchers.IO) {
val transactionId = createShieldingSpend(spendingKey, transparentSecretKey, memo) val transactionId = createShieldingSpend(spendingKey, transparentSecretKey, memo)
repository.findEncodedTransactionById(transactionId) repository.findEncodedTransactionById(transactionId)
?: throw TransactionEncoderException.TransactionNotFoundException(transactionId) ?: throw TransactionEncoderException.TransactionNotFoundException(transactionId)
@ -68,7 +69,7 @@ class WalletTransactionEncoder(
* *
* @return true when the given address is a valid z-addr * @return true when the given address is a valid z-addr
*/ */
override suspend fun isValidShieldedAddress(address: String): Boolean = withContext(IO) { override suspend fun isValidShieldedAddress(address: String): Boolean = withContext(Dispatchers.IO) {
rustBackend.isValidShieldedAddr(address) rustBackend.isValidShieldedAddr(address)
} }
@ -80,7 +81,7 @@ class WalletTransactionEncoder(
* *
* @return true when the given address is a valid t-addr * @return true when the given address is a valid t-addr
*/ */
override suspend fun isValidTransparentAddress(address: String): Boolean = withContext(IO) { override suspend fun isValidTransparentAddress(address: String): Boolean = withContext(Dispatchers.IO) {
rustBackend.isValidTransparentAddr(address) rustBackend.isValidTransparentAddr(address)
} }
@ -110,7 +111,7 @@ class WalletTransactionEncoder(
toAddress: String, toAddress: String,
memo: ByteArray? = byteArrayOf(), memo: ByteArray? = byteArrayOf(),
fromAccountIndex: Int = 0 fromAccountIndex: Int = 0
): Long = withContext(IO) { ): Long = withContext(Dispatchers.IO) {
twigTask( twigTask(
"creating transaction to spend $zatoshi zatoshi to" + "creating transaction to spend $zatoshi zatoshi to" +
" ${toAddress.masked()} with memo $memo" " ${toAddress.masked()} with memo $memo"
@ -140,7 +141,7 @@ class WalletTransactionEncoder(
spendingKey: String, spendingKey: String,
transparentSecretKey: String, transparentSecretKey: String,
memo: ByteArray? = byteArrayOf() memo: ByteArray? = byteArrayOf()
): Long = withContext(IO) { ): Long = withContext(Dispatchers.IO) {
twigTask("creating transaction to shield all UTXOs") { twigTask("creating transaction to shield all UTXOs") {
try { try {
SaplingParamTool.ensureParams((rustBackend as RustBackend).pathParamsDir) SaplingParamTool.ensureParams((rustBackend as RustBackend).pathParamsDir)

View File

@ -0,0 +1,44 @@
package cash.z.ecc.android.sdk.jni
import cash.z.ecc.android.sdk.internal.twig
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext
import java.util.concurrent.atomic.AtomicBoolean
/**
* Loads a native library once. This class is thread-safe.
*
* To use this class, create a singleton instance for each given [libraryName].
*
* @param libraryName Name of the library to load.
*/
internal class NativeLibraryLoader(private val libraryName: String) {
private val isLoaded = AtomicBoolean(false)
private val mutex = Mutex()
suspend fun load() {
// Double-checked locking to avoid the Mutex unless necessary, as the hot path is
// for the library to be loaded since this should only run once for the lifetime
// of the application
if (!isLoaded.get()) {
mutex.withLock {
if (!isLoaded.get()) {
loadRustLibrary()
}
}
}
}
private suspend fun loadRustLibrary() {
try {
withContext(Dispatchers.IO) {
twig("Loading native library $libraryName") { System.loadLibrary(libraryName) }
}
isLoaded.set(true)
} catch (e: Throwable) {
twig("Error while loading native library: ${e.message}")
}
}
}

View File

@ -4,10 +4,14 @@ import cash.z.ecc.android.sdk.exception.BirthdayException
import cash.z.ecc.android.sdk.ext.ZcashSdk.OUTPUT_PARAM_FILE_NAME import cash.z.ecc.android.sdk.ext.ZcashSdk.OUTPUT_PARAM_FILE_NAME
import cash.z.ecc.android.sdk.ext.ZcashSdk.SPEND_PARAM_FILE_NAME import cash.z.ecc.android.sdk.ext.ZcashSdk.SPEND_PARAM_FILE_NAME
import cash.z.ecc.android.sdk.internal.twig import cash.z.ecc.android.sdk.internal.twig
import cash.z.ecc.android.sdk.internal.SdkDispatchers
import cash.z.ecc.android.sdk.internal.ext.deleteSuspend
import cash.z.ecc.android.sdk.tool.DerivationTool import cash.z.ecc.android.sdk.tool.DerivationTool
import cash.z.ecc.android.sdk.type.UnifiedViewingKey import cash.z.ecc.android.sdk.type.UnifiedViewingKey
import cash.z.ecc.android.sdk.type.WalletBalance import cash.z.ecc.android.sdk.type.WalletBalance
import cash.z.ecc.android.sdk.type.ZcashNetwork import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.io.File import java.io.File
/** /**
@ -17,10 +21,6 @@ import java.io.File
*/ */
class RustBackend private constructor() : RustBackendWelding { class RustBackend private constructor() : RustBackendWelding {
init {
load()
}
// Paths // Paths
lateinit var pathDataDb: String lateinit var pathDataDb: String
internal set internal set
@ -35,14 +35,14 @@ class RustBackend private constructor() : RustBackendWelding {
get() = if (field != -1) field else throw BirthdayException.UninitializedBirthdayException get() = if (field != -1) field else throw BirthdayException.UninitializedBirthdayException
private set private set
fun clear(clearCacheDb: Boolean = true, clearDataDb: Boolean = true) { suspend fun clear(clearCacheDb: Boolean = true, clearDataDb: Boolean = true) {
if (clearCacheDb) { if (clearCacheDb) {
twig("Deleting the cache database!") twig("Deleting the cache database!")
File(pathCacheDb).delete() File(pathCacheDb).deleteSuspend()
} }
if (clearDataDb) { if (clearDataDb) {
twig("Deleting the data database!") twig("Deleting the data database!")
File(pathDataDb).delete() File(pathDataDb).deleteSuspend()
} }
} }
@ -50,19 +50,31 @@ class RustBackend private constructor() : RustBackendWelding {
// Wrapper Functions // Wrapper Functions
// //
override fun initDataDb() = initDataDb(pathDataDb, networkId = network.id) override suspend fun initDataDb() = withContext(SdkDispatchers.IO) {
initDataDb(
pathDataDb,
networkId = network.id
)
}
override fun initAccountsTable(vararg keys: UnifiedViewingKey): Boolean { override suspend fun initAccountsTable(vararg keys: UnifiedViewingKey): Boolean {
val extfvks = Array(keys.size) { "" } val extfvks = Array(keys.size) { "" }
val extpubs = Array(keys.size) { "" } val extpubs = Array(keys.size) { "" }
keys.forEachIndexed { i, key -> keys.forEachIndexed { i, key ->
extfvks[i] = key.extfvk extfvks[i] = key.extfvk
extpubs[i] = key.extpub extpubs[i] = key.extpub
} }
return initAccountsTableWithKeys(pathDataDb, extfvks, extpubs, networkId = network.id) return withContext(SdkDispatchers.IO) {
initAccountsTableWithKeys(
pathDataDb,
extfvks,
extpubs,
networkId = network.id
)
}
} }
override fun initAccountsTable( override suspend fun initAccountsTable(
seed: ByteArray, seed: ByteArray,
numberOfAccounts: Int numberOfAccounts: Int
): Array<UnifiedViewingKey> { ): Array<UnifiedViewingKey> {
@ -71,82 +83,131 @@ class RustBackend private constructor() : RustBackendWelding {
} }
} }
override fun initBlocksTable( override suspend fun initBlocksTable(
height: Int, height: Int,
hash: String, hash: String,
time: Long, time: Long,
saplingTree: String saplingTree: String
): Boolean { ): Boolean {
return initBlocksTable(pathDataDb, height, hash, time, saplingTree, networkId = network.id) return withContext(SdkDispatchers.IO) {
initBlocksTable(
pathDataDb,
height,
hash,
time,
saplingTree,
networkId = network.id
)
}
} }
override fun getShieldedAddress(account: Int) = getShieldedAddress(pathDataDb, account, networkId = network.id) override suspend fun getShieldedAddress(account: Int) = withContext(SdkDispatchers.IO) {
getShieldedAddress(
pathDataDb,
account,
networkId = network.id
)
}
override fun getTransparentAddress(account: Int, index: Int): String { override suspend fun getTransparentAddress(account: Int, index: Int): String {
throw NotImplementedError("TODO: implement this at the zcash_client_sqlite level. But for now, use DerivationTool, instead to derive addresses from seeds") throw NotImplementedError("TODO: implement this at the zcash_client_sqlite level. But for now, use DerivationTool, instead to derive addresses from seeds")
} }
override fun getBalance(account: Int) = getBalance(pathDataDb, account, networkId = network.id) override suspend fun getBalance(account: Int) = withContext(SdkDispatchers.IO) {
getBalance(
pathDataDb,
account,
networkId = network.id
)
}
override fun getVerifiedBalance(account: Int) = getVerifiedBalance(pathDataDb, account, networkId = network.id) override suspend fun getVerifiedBalance(account: Int) = withContext(SdkDispatchers.IO) {
getVerifiedBalance(
pathDataDb,
account,
networkId = network.id
)
}
override fun getReceivedMemoAsUtf8(idNote: Long) = override suspend fun getReceivedMemoAsUtf8(idNote: Long) =
getReceivedMemoAsUtf8(pathDataDb, idNote, networkId = network.id) withContext(SdkDispatchers.IO) { getReceivedMemoAsUtf8(pathDataDb, idNote, networkId = network.id) }
override fun getSentMemoAsUtf8(idNote: Long) = getSentMemoAsUtf8(pathDataDb, idNote, networkId = network.id) override suspend fun getSentMemoAsUtf8(idNote: Long) = withContext(SdkDispatchers.IO) {
getSentMemoAsUtf8(
pathDataDb,
idNote,
networkId = network.id
)
}
override fun validateCombinedChain() = validateCombinedChain(pathCacheDb, pathDataDb, networkId = network.id,) override suspend fun validateCombinedChain() = withContext(SdkDispatchers.IO) {
validateCombinedChain(
pathCacheDb,
pathDataDb,
networkId = network.id,
)
}
override fun getNearestRewindHeight(height: Int): Int = getNearestRewindHeight(pathDataDb, height, networkId = network.id) override suspend fun getNearestRewindHeight(height: Int): Int = withContext(SdkDispatchers.IO) {
getNearestRewindHeight(
pathDataDb,
height,
networkId = network.id
)
}
/** /**
* Deletes data for all blocks above the given height. Boils down to: * Deletes data for all blocks above the given height. Boils down to:
* *
* DELETE FROM blocks WHERE height > ? * DELETE FROM blocks WHERE height > ?
*/ */
override fun rewindToHeight(height: Int) = rewindToHeight(pathDataDb, height, networkId = network.id) override suspend fun rewindToHeight(height: Int) =
withContext(SdkDispatchers.IO) { rewindToHeight(pathDataDb, height, networkId = network.id) }
override fun scanBlocks(limit: Int): Boolean { override suspend fun scanBlocks(limit: Int): Boolean {
return if (limit > 0) { return if (limit > 0) {
scanBlockBatch(pathCacheDb, pathDataDb, limit, networkId = network.id) withContext(SdkDispatchers.IO) {
scanBlockBatch(
pathCacheDb,
pathDataDb,
limit,
networkId = network.id
)
}
} else { } else {
scanBlocks(pathCacheDb, pathDataDb, networkId = network.id) withContext(SdkDispatchers.IO) {
scanBlocks(
pathCacheDb,
pathDataDb,
networkId = network.id
)
}
} }
} }
override fun decryptAndStoreTransaction(tx: ByteArray) = decryptAndStoreTransaction(pathDataDb, tx, networkId = network.id) override suspend fun decryptAndStoreTransaction(tx: ByteArray) = withContext(SdkDispatchers.IO) {
decryptAndStoreTransaction(
pathDataDb,
tx,
networkId = network.id
)
}
override fun createToAddress( override suspend fun createToAddress(
consensusBranchId: Long, consensusBranchId: Long,
account: Int, account: Int,
extsk: String, extsk: String,
to: String, to: String,
value: Long, value: Long,
memo: ByteArray? memo: ByteArray?
): Long = createToAddress( ): Long = withContext(SdkDispatchers.IO) {
pathDataDb, createToAddress(
consensusBranchId,
account,
extsk,
to,
value,
memo ?: ByteArray(0),
"$pathParamsDir/$SPEND_PARAM_FILE_NAME",
"$pathParamsDir/$OUTPUT_PARAM_FILE_NAME",
networkId = network.id,
)
override fun shieldToAddress(
extsk: String,
tsk: String,
memo: ByteArray?
): Long {
twig("TMP: shieldToAddress with db path: $pathDataDb, ${memo?.size}")
return shieldToAddress(
pathDataDb, pathDataDb,
0, consensusBranchId,
account,
extsk, extsk,
tsk, to,
value,
memo ?: ByteArray(0), memo ?: ByteArray(0),
"$pathParamsDir/$SPEND_PARAM_FILE_NAME", "$pathParamsDir/$SPEND_PARAM_FILE_NAME",
"$pathParamsDir/$OUTPUT_PARAM_FILE_NAME", "$pathParamsDir/$OUTPUT_PARAM_FILE_NAME",
@ -154,31 +215,84 @@ class RustBackend private constructor() : RustBackendWelding {
) )
} }
override fun putUtxo( override suspend fun shieldToAddress(
extsk: String,
tsk: String,
memo: ByteArray?
): Long {
twig("TMP: shieldToAddress with db path: $pathDataDb, ${memo?.size}")
return withContext(SdkDispatchers.IO) {
shieldToAddress(
pathDataDb,
0,
extsk,
tsk,
memo ?: ByteArray(0),
"$pathParamsDir/$SPEND_PARAM_FILE_NAME",
"$pathParamsDir/$OUTPUT_PARAM_FILE_NAME",
networkId = network.id,
)
}
}
override suspend fun putUtxo(
tAddress: String, tAddress: String,
txId: ByteArray, txId: ByteArray,
index: Int, index: Int,
script: ByteArray, script: ByteArray,
value: Long, value: Long,
height: Int height: Int
): Boolean = putUtxo(pathDataDb, tAddress, txId, index, script, value, height, networkId = network.id) ): Boolean = withContext(SdkDispatchers.IO) {
putUtxo(
pathDataDb,
tAddress,
txId,
index,
script,
value,
height,
networkId = network.id
)
}
override fun clearUtxos( override suspend fun clearUtxos(
tAddress: String, tAddress: String,
aboveHeight: Int, aboveHeight: Int,
): Boolean = clearUtxos(pathDataDb, tAddress, aboveHeight, networkId = network.id) ): Boolean = withContext(SdkDispatchers.IO) {
clearUtxos(
pathDataDb,
tAddress,
aboveHeight,
networkId = network.id
)
}
override fun getDownloadedUtxoBalance(address: String): WalletBalance { override suspend fun getDownloadedUtxoBalance(address: String): WalletBalance {
val verified = getVerifiedTransparentBalance(pathDataDb, address, networkId = network.id) val verified = withContext(SdkDispatchers.IO) {
val total = getTotalTransparentBalance(pathDataDb, address, networkId = network.id) getVerifiedTransparentBalance(
pathDataDb,
address,
networkId = network.id
)
}
val total = withContext(SdkDispatchers.IO) {
getTotalTransparentBalance(
pathDataDb,
address,
networkId = network.id
)
}
return WalletBalance(total, verified) return WalletBalance(total, verified)
} }
override fun isValidShieldedAddr(addr: String) = isValidShieldedAddress(addr, networkId = network.id) override fun isValidShieldedAddr(addr: String) =
isValidShieldedAddress(addr, networkId = network.id)
override fun isValidTransparentAddr(addr: String) = isValidTransparentAddress(addr, networkId = network.id) override fun isValidTransparentAddr(addr: String) =
isValidTransparentAddress(addr, networkId = network.id)
override fun getBranchIdForHeight(height: Int): Long = branchIdForHeight(height, networkId = network.id) override fun getBranchIdForHeight(height: Int): Long =
branchIdForHeight(height, networkId = network.id)
// /** // /**
// * This is a proof-of-concept for doing Local RPC, where we are effectively using the JNI // * This is a proof-of-concept for doing Local RPC, where we are effectively using the JNI
@ -203,19 +317,21 @@ class RustBackend private constructor() : RustBackendWelding {
* Exposes all of the librustzcash functions along with helpers for loading the static library. * Exposes all of the librustzcash functions along with helpers for loading the static library.
*/ */
companion object { companion object {
private var loaded = false internal val rustLibraryLoader = NativeLibraryLoader("zcashwalletsdk")
/** /**
* Loads the library and initializes path variables. Although it is best to only call this * Loads the library and initializes path variables. Although it is best to only call this
* function once, it is idempotent. * function once, it is idempotent.
*/ */
fun init( suspend fun init(
cacheDbPath: String, cacheDbPath: String,
dataDbPath: String, dataDbPath: String,
paramsPath: String, paramsPath: String,
zcashNetwork: ZcashNetwork, zcashNetwork: ZcashNetwork,
birthdayHeight: Int? = null birthdayHeight: Int? = null
): RustBackend { ): RustBackend {
rustLibraryLoader.load()
return RustBackend().apply { return RustBackend().apply {
pathCacheDb = cacheDbPath pathCacheDb = cacheDbPath
pathDataDb = dataDbPath pathDataDb = dataDbPath
@ -227,16 +343,6 @@ class RustBackend private constructor() : RustBackendWelding {
} }
} }
fun load() {
// It is safe to call these things twice but not efficient. So we add a loose check and
// ignore the fact that it's not thread-safe.
if (!loaded) {
twig("Loading RustBackend") {
loadRustLibrary()
}
}
}
/** /**
* Forwards Rust logs to logcat. This is a function that is intended for debug purposes. All * Forwards Rust logs to logcat. This is a function that is intended for debug purposes. All
* logs will be tagged with `cash.z.rust.logs`. Typically, a developer would clone * logs will be tagged with `cash.z.rust.logs`. Typically, a developer would clone
@ -249,33 +355,24 @@ class RustBackend private constructor() : RustBackendWelding {
*/ */
fun enableRustLogs() = initLogs() fun enableRustLogs() = initLogs()
/**
* The first call made to this object in order to load the Rust backend library. All other
* external function calls will fail if the libraries have not been loaded.
*/
private fun loadRustLibrary() {
try {
System.loadLibrary("zcashwalletsdk")
loaded = true
} catch (e: Throwable) {
twig("Error while loading native library: ${e.message}")
}
}
// //
// External Functions // External Functions
// //
@JvmStatic private external fun initDataDb(dbDataPath: String, networkId: Int): Boolean @JvmStatic
private external fun initDataDb(dbDataPath: String, networkId: Int): Boolean
@JvmStatic private external fun initAccountsTableWithKeys( @JvmStatic
private external fun initAccountsTableWithKeys(
dbDataPath: String, dbDataPath: String,
extfvk: Array<out String>, extfvk: Array<out String>,
extpub: Array<out String>, extpub: Array<out String>,
networkId: Int, networkId: Int,
): Boolean ): Boolean
@JvmStatic private external fun initBlocksTable( @JvmStatic
private external fun initBlocksTable(
dbDataPath: String, dbDataPath: String,
height: Int, height: Int,
hash: String, hash: String,
@ -364,7 +461,8 @@ class RustBackend private constructor() : RustBackendWelding {
networkId: Int, networkId: Int,
) )
@JvmStatic private external fun createToAddress( @JvmStatic
private external fun createToAddress(
dbDataPath: String, dbDataPath: String,
consensusBranchId: Long, consensusBranchId: Long,
account: Int, account: Int,
@ -377,7 +475,8 @@ class RustBackend private constructor() : RustBackendWelding {
networkId: Int, networkId: Int,
): Long ): Long
@JvmStatic private external fun shieldToAddress( @JvmStatic
private external fun shieldToAddress(
dbDataPath: String, dbDataPath: String,
account: Int, account: Int,
extsk: String, extsk: String,
@ -388,11 +487,14 @@ class RustBackend private constructor() : RustBackendWelding {
networkId: Int, networkId: Int,
): Long ): Long
@JvmStatic private external fun initLogs() @JvmStatic
private external fun initLogs()
@JvmStatic private external fun branchIdForHeight(height: Int, networkId: Int): Long @JvmStatic
private external fun branchIdForHeight(height: Int, networkId: Int): Long
@JvmStatic private external fun putUtxo( @JvmStatic
private external fun putUtxo(
dbDataPath: String, dbDataPath: String,
tAddress: String, tAddress: String,
txId: ByteArray, txId: ByteArray,
@ -403,23 +505,27 @@ class RustBackend private constructor() : RustBackendWelding {
networkId: Int, networkId: Int,
): Boolean ): Boolean
@JvmStatic private external fun clearUtxos( @JvmStatic
private external fun clearUtxos(
dbDataPath: String, dbDataPath: String,
tAddress: String, tAddress: String,
aboveHeight: Int, aboveHeight: Int,
networkId: Int, networkId: Int,
): Boolean ): Boolean
@JvmStatic private external fun getVerifiedTransparentBalance( @JvmStatic
private external fun getVerifiedTransparentBalance(
pathDataDb: String, pathDataDb: String,
taddr: String, taddr: String,
networkId: Int, networkId: Int,
): Long ): Long
@JvmStatic private external fun getTotalTransparentBalance( @JvmStatic
private external fun getTotalTransparentBalance(
pathDataDb: String, pathDataDb: String,
taddr: String, taddr: String,
networkId: Int, networkId: Int,
): Long ): Long
} }
} }

View File

@ -14,7 +14,7 @@ interface RustBackendWelding {
val network: ZcashNetwork val network: ZcashNetwork
fun createToAddress( suspend fun createToAddress(
consensusBranchId: Long, consensusBranchId: Long,
account: Int, account: Int,
extsk: String, extsk: String,
@ -23,51 +23,51 @@ interface RustBackendWelding {
memo: ByteArray? = byteArrayOf() memo: ByteArray? = byteArrayOf()
): Long ): Long
fun shieldToAddress( suspend fun shieldToAddress(
extsk: String, extsk: String,
tsk: String, tsk: String,
memo: ByteArray? = byteArrayOf() memo: ByteArray? = byteArrayOf()
): Long ): Long
fun decryptAndStoreTransaction(tx: ByteArray) suspend fun decryptAndStoreTransaction(tx: ByteArray)
fun initAccountsTable(seed: ByteArray, numberOfAccounts: Int): Array<UnifiedViewingKey> suspend fun initAccountsTable(seed: ByteArray, numberOfAccounts: Int): Array<UnifiedViewingKey>
fun initAccountsTable(vararg keys: UnifiedViewingKey): Boolean suspend fun initAccountsTable(vararg keys: UnifiedViewingKey): Boolean
fun initBlocksTable(height: Int, hash: String, time: Long, saplingTree: String): Boolean suspend fun initBlocksTable(height: Int, hash: String, time: Long, saplingTree: String): Boolean
fun initDataDb(): Boolean suspend fun initDataDb(): Boolean
fun isValidShieldedAddr(addr: String): Boolean fun isValidShieldedAddr(addr: String): Boolean
fun isValidTransparentAddr(addr: String): Boolean fun isValidTransparentAddr(addr: String): Boolean
fun getShieldedAddress(account: Int = 0): String suspend fun getShieldedAddress(account: Int = 0): String
fun getTransparentAddress(account: Int = 0, index: Int = 0): String suspend fun getTransparentAddress(account: Int = 0, index: Int = 0): String
fun getBalance(account: Int = 0): Long suspend fun getBalance(account: Int = 0): Long
fun getBranchIdForHeight(height: Int): Long fun getBranchIdForHeight(height: Int): Long
fun getReceivedMemoAsUtf8(idNote: Long): String suspend fun getReceivedMemoAsUtf8(idNote: Long): String
fun getSentMemoAsUtf8(idNote: Long): String suspend fun getSentMemoAsUtf8(idNote: Long): String
fun getVerifiedBalance(account: Int = 0): Long suspend fun getVerifiedBalance(account: Int = 0): Long
// fun parseTransactionDataList(tdl: LocalRpcTypes.TransactionDataList): LocalRpcTypes.TransparentTransactionList // fun parseTransactionDataList(tdl: LocalRpcTypes.TransactionDataList): LocalRpcTypes.TransparentTransactionList
fun getNearestRewindHeight(height: Int): Int suspend fun getNearestRewindHeight(height: Int): Int
fun rewindToHeight(height: Int): Boolean suspend fun rewindToHeight(height: Int): Boolean
fun scanBlocks(limit: Int = -1): Boolean suspend fun scanBlocks(limit: Int = -1): Boolean
fun validateCombinedChain(): Int suspend fun validateCombinedChain(): Int
fun putUtxo( suspend fun putUtxo(
tAddress: String, tAddress: String,
txId: ByteArray, txId: ByteArray,
index: Int, index: Int,
@ -76,59 +76,59 @@ interface RustBackendWelding {
height: Int height: Int
): Boolean ): Boolean
fun clearUtxos(tAddress: String, aboveHeight: Int = network.saplingActivationHeight - 1): Boolean suspend fun clearUtxos(tAddress: String, aboveHeight: Int = network.saplingActivationHeight - 1): Boolean
fun getDownloadedUtxoBalance(address: String): WalletBalance suspend fun getDownloadedUtxoBalance(address: String): WalletBalance
// Implemented by `DerivationTool` // Implemented by `DerivationTool`
interface Derivation { interface Derivation {
fun deriveShieldedAddress( suspend fun deriveShieldedAddress(
viewingKey: String, viewingKey: String,
network: ZcashNetwork network: ZcashNetwork
): String ): String
fun deriveShieldedAddress( suspend fun deriveShieldedAddress(
seed: ByteArray, seed: ByteArray,
network: ZcashNetwork, network: ZcashNetwork,
accountIndex: Int = 0, accountIndex: Int = 0,
): String ): String
fun deriveSpendingKeys( suspend fun deriveSpendingKeys(
seed: ByteArray, seed: ByteArray,
network: ZcashNetwork, network: ZcashNetwork,
numberOfAccounts: Int = 1, numberOfAccounts: Int = 1,
): Array<String> ): Array<String>
fun deriveTransparentAddress( suspend fun deriveTransparentAddress(
seed: ByteArray, seed: ByteArray,
network: ZcashNetwork, network: ZcashNetwork,
account: Int = 0, account: Int = 0,
index: Int = 0, index: Int = 0,
): String ): String
fun deriveTransparentAddressFromPublicKey( suspend fun deriveTransparentAddressFromPublicKey(
publicKey: String, publicKey: String,
network: ZcashNetwork network: ZcashNetwork
): String ): String
fun deriveTransparentAddressFromPrivateKey( suspend fun deriveTransparentAddressFromPrivateKey(
privateKey: String, privateKey: String,
network: ZcashNetwork network: ZcashNetwork
): String ): String
fun deriveTransparentSecretKey( suspend fun deriveTransparentSecretKey(
seed: ByteArray, seed: ByteArray,
network: ZcashNetwork, network: ZcashNetwork,
account: Int = 0, account: Int = 0,
index: Int = 0, index: Int = 0,
): String ): String
fun deriveViewingKey( suspend fun deriveViewingKey(
spendingKey: String, spendingKey: String,
network: ZcashNetwork network: ZcashNetwork
): String ): String
fun deriveUnifiedViewingKeys( suspend fun deriveUnifiedViewingKeys(
seed: ByteArray, seed: ByteArray,
network: ZcashNetwork, network: ZcashNetwork,
numberOfAccounts: Int = 1, numberOfAccounts: Int = 1,

View File

@ -18,7 +18,7 @@ class DerivationTool {
* *
* @return the viewing keys that correspond to the seed, formatted as Strings. * @return the viewing keys that correspond to the seed, formatted as Strings.
*/ */
override fun deriveUnifiedViewingKeys(seed: ByteArray, network: ZcashNetwork, numberOfAccounts: Int): Array<UnifiedViewingKey> = override suspend fun deriveUnifiedViewingKeys(seed: ByteArray, network: ZcashNetwork, numberOfAccounts: Int): Array<UnifiedViewingKey> =
withRustBackendLoaded { withRustBackendLoaded {
deriveUnifiedViewingKeysFromSeed(seed, numberOfAccounts, networkId = network.id).map { deriveUnifiedViewingKeysFromSeed(seed, numberOfAccounts, networkId = network.id).map {
UnifiedViewingKey(it[0], it[1]) UnifiedViewingKey(it[0], it[1])
@ -32,7 +32,7 @@ class DerivationTool {
* *
* @return the viewing key that corresponds to the spending key. * @return the viewing key that corresponds to the spending key.
*/ */
override fun deriveViewingKey(spendingKey: String, network: ZcashNetwork): String = withRustBackendLoaded { override suspend fun deriveViewingKey(spendingKey: String, network: ZcashNetwork): String = withRustBackendLoaded {
deriveExtendedFullViewingKey(spendingKey, networkId = network.id) deriveExtendedFullViewingKey(spendingKey, networkId = network.id)
} }
@ -45,7 +45,7 @@ class DerivationTool {
* *
* @return the spending keys that correspond to the seed, formatted as Strings. * @return the spending keys that correspond to the seed, formatted as Strings.
*/ */
override fun deriveSpendingKeys(seed: ByteArray, network: ZcashNetwork, numberOfAccounts: Int): Array<String> = override suspend fun deriveSpendingKeys(seed: ByteArray, network: ZcashNetwork, numberOfAccounts: Int): Array<String> =
withRustBackendLoaded { withRustBackendLoaded {
deriveExtendedSpendingKeys(seed, numberOfAccounts, networkId = network.id) deriveExtendedSpendingKeys(seed, numberOfAccounts, networkId = network.id)
} }
@ -59,7 +59,7 @@ class DerivationTool {
* *
* @return the address that corresponds to the seed and account index. * @return the address that corresponds to the seed and account index.
*/ */
override fun deriveShieldedAddress(seed: ByteArray, network: ZcashNetwork, accountIndex: Int): String = override suspend fun deriveShieldedAddress(seed: ByteArray, network: ZcashNetwork, accountIndex: Int): String =
withRustBackendLoaded { withRustBackendLoaded {
deriveShieldedAddressFromSeed(seed, accountIndex, networkId = network.id) deriveShieldedAddressFromSeed(seed, accountIndex, networkId = network.id)
} }
@ -72,26 +72,26 @@ class DerivationTool {
* *
* @return the address that corresponds to the viewing key. * @return the address that corresponds to the viewing key.
*/ */
override fun deriveShieldedAddress(viewingKey: String, network: ZcashNetwork): String = withRustBackendLoaded { override suspend fun deriveShieldedAddress(viewingKey: String, network: ZcashNetwork): String = withRustBackendLoaded {
deriveShieldedAddressFromViewingKey(viewingKey, networkId = network.id) deriveShieldedAddressFromViewingKey(viewingKey, networkId = network.id)
} }
// WIP probably shouldn't be used just yet. Why? // WIP probably shouldn't be used just yet. Why?
// - because we need the private key associated with this seed and this function doesn't return it. // - because we need the private key associated with this seed and this function doesn't return it.
// - the underlying implementation needs to be split out into a few lower-level calls // - the underlying implementation needs to be split out into a few lower-level calls
override fun deriveTransparentAddress(seed: ByteArray, network: ZcashNetwork, account: Int, index: Int): String = withRustBackendLoaded { override suspend fun deriveTransparentAddress(seed: ByteArray, network: ZcashNetwork, account: Int, index: Int): String = withRustBackendLoaded {
deriveTransparentAddressFromSeed(seed, account, index, networkId = network.id) deriveTransparentAddressFromSeed(seed, account, index, networkId = network.id)
} }
override fun deriveTransparentAddressFromPublicKey(transparentPublicKey: String, network: ZcashNetwork): String = withRustBackendLoaded { override suspend fun deriveTransparentAddressFromPublicKey(transparentPublicKey: String, network: ZcashNetwork): String = withRustBackendLoaded {
deriveTransparentAddressFromPubKey(transparentPublicKey, networkId = network.id) deriveTransparentAddressFromPubKey(transparentPublicKey, networkId = network.id)
} }
override fun deriveTransparentAddressFromPrivateKey(transparentPrivateKey: String, network: ZcashNetwork): String = withRustBackendLoaded { override suspend fun deriveTransparentAddressFromPrivateKey(transparentPrivateKey: String, network: ZcashNetwork): String = withRustBackendLoaded {
deriveTransparentAddressFromPrivKey(transparentPrivateKey, networkId = network.id) deriveTransparentAddressFromPrivKey(transparentPrivateKey, networkId = network.id)
} }
override fun deriveTransparentSecretKey(seed: ByteArray, network: ZcashNetwork, account: Int, index: Int): String = withRustBackendLoaded { override suspend fun deriveTransparentSecretKey(seed: ByteArray, network: ZcashNetwork, account: Int, index: Int): String = withRustBackendLoaded {
deriveTransparentSecretKeyFromSeed(seed, account, index, networkId = network.id) deriveTransparentSecretKeyFromSeed(seed, account, index, networkId = network.id)
} }
@ -104,8 +104,8 @@ class DerivationTool {
* class attempts to interact with it, indirectly, by invoking JNI functions. It would be * class attempts to interact with it, indirectly, by invoking JNI functions. It would be
* nice to have an annotation like @UsesSystemLibrary for this * nice to have an annotation like @UsesSystemLibrary for this
*/ */
private fun <T> withRustBackendLoaded(block: () -> T): T { private suspend fun <T> withRustBackendLoaded(block: () -> T): T {
RustBackend.load() RustBackend.rustLibraryLoader.load()
return block() return block()
} }

View File

@ -8,165 +8,156 @@ import cash.z.ecc.android.sdk.type.WalletBirthday
import cash.z.ecc.android.sdk.type.ZcashNetwork import cash.z.ecc.android.sdk.type.ZcashNetwork
import com.google.gson.Gson import com.google.gson.Gson
import com.google.gson.stream.JsonReader import com.google.gson.stream.JsonReader
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.io.IOException import java.io.IOException
import java.io.InputStreamReader import java.io.InputStreamReader
import java.util.Locale import java.util.Locale
/** /**
* Tool for loading checkpoints for the wallet, based on the height at which the wallet was born. * Tool for loading checkpoints for the wallet, based on the height at which the wallet was born.
*
* @param appContext needed for loading checkpoints from the app's assets directory.
*/ */
class WalletBirthdayTool(appContext: Context) { object WalletBirthdayTool {
private val context = appContext.applicationContext
// Behavior change implemented as a fix for issue #270. Temporarily adding a boolean
// that allows the change to be rolled back quickly if needed, although long-term
// this flag should be removed.
@VisibleForTesting
internal val IS_FALLBACK_ON_FAILURE = true
/** /**
* Load the nearest checkpoint to the given birthday height. If null is given, then this * Load the nearest checkpoint to the given birthday height. If null is given, then this
* will load the most recent checkpoint available. * will load the most recent checkpoint available.
*/ */
fun loadNearest(network: ZcashNetwork, birthdayHeight: Int? = null): WalletBirthday { suspend fun loadNearest(
context: Context,
network: ZcashNetwork,
birthdayHeight: Int? = null
): WalletBirthday {
// TODO: potentially pull from shared preferences first
return loadBirthdayFromAssets(context, network, birthdayHeight) return loadBirthdayFromAssets(context, network, birthdayHeight)
} }
companion object { /**
* Useful for when an exact checkpoint is needed, like for SAPLING_ACTIVATION_HEIGHT. In
// Behavior change implemented as a fix for issue #270. Temporarily adding a boolean * most cases, loading the nearest checkpoint is preferred for privacy reasons.
// that allows the change to be rolled back quickly if needed, although long-term */
// this flag should be removed. suspend fun loadExact(context: Context, network: ZcashNetwork, birthdayHeight: Int) =
@VisibleForTesting loadNearest(context, network, birthdayHeight).also {
internal val IS_FALLBACK_ON_FAILURE = true if (it.height != birthdayHeight)
throw BirthdayException.ExactBirthdayNotFoundException(
/** birthdayHeight,
* Load the nearest checkpoint to the given birthday height. If null is given, then this it.height
* will load the most recent checkpoint available.
*/
fun loadNearest(
context: Context,
network: ZcashNetwork,
birthdayHeight: Int? = null
): WalletBirthday {
// TODO: potentially pull from shared preferences first
return loadBirthdayFromAssets(context, network, birthdayHeight)
}
/**
* Useful for when an exact checkpoint is needed, like for SAPLING_ACTIVATION_HEIGHT. In
* most cases, loading the nearest checkpoint is preferred for privacy reasons.
*/
fun loadExact(context: Context, network: ZcashNetwork, birthdayHeight: Int) =
loadNearest(context, network, birthdayHeight).also {
if (it.height != birthdayHeight)
throw BirthdayException.ExactBirthdayNotFoundException(
birthdayHeight,
it.height
)
}
// TODO: This method performs disk IO; convert to suspending function
// Converting this to suspending will then propagate
@Throws(IOException::class)
internal fun listBirthdayDirectoryContents(context: Context, directory: String) =
context.assets.list(directory)
/**
* Returns the directory within the assets folder where birthday data
* (i.e. sapling trees for a given height) can be found.
*/
@VisibleForTesting
internal fun birthdayDirectory(network: ZcashNetwork) =
"saplingtree/${(network.networkName as java.lang.String).toLowerCase(Locale.US)}"
internal fun birthdayHeight(fileName: String) = fileName.split('.').first().toInt()
private fun Array<String>.sortDescending() =
apply { sortByDescending { birthdayHeight(it) } }
/**
* Load the given birthday file from the assets of the given context. When no height is
* specified, we default to the file with the greatest name.
*
* @param context the context from which to load assets.
* @param birthdayHeight the height file to look for among the file names.
*
* @return a WalletBirthday that reflects the contents of the file or an exception when
* parsing fails.
*/
private fun loadBirthdayFromAssets(
context: Context,
network: ZcashNetwork,
birthdayHeight: Int? = null
): WalletBirthday {
twig("loading birthday from assets: $birthdayHeight")
val directory = birthdayDirectory(network)
val treeFiles = getFilteredFileNames(context, directory, birthdayHeight)
twig("found ${treeFiles.size} sapling tree checkpoints: $treeFiles")
return getFirstValidWalletBirthday(context, directory, treeFiles)
}
private fun getFilteredFileNames(
context: Context,
directory: String,
birthdayHeight: Int? = null,
): List<String> {
val unfilteredTreeFiles = listBirthdayDirectoryContents(context, directory)
if (unfilteredTreeFiles.isNullOrEmpty()) {
throw BirthdayException.MissingBirthdayFilesException(directory)
}
val filteredTreeFiles = unfilteredTreeFiles
.sortDescending()
.filter { filename ->
birthdayHeight?.let { birthdayHeight(filename) <= it } ?: true
}
if (filteredTreeFiles.isEmpty()) {
throw BirthdayException.BirthdayFileNotFoundException(
directory,
birthdayHeight
) )
}
return filteredTreeFiles
} }
/** // Converting this to suspending will then propagate
* @param treeFiles A list of files, sorted in descending order based on `int` value of the first part of the filename. @Throws(IOException::class)
*/ internal suspend fun listBirthdayDirectoryContents(context: Context, directory: String) =
@VisibleForTesting withContext(Dispatchers.IO) {
internal fun getFirstValidWalletBirthday( context.assets.list(directory)
context: Context, }
directory: String,
treeFiles: List<String> /**
): WalletBirthday { * Returns the directory within the assets folder where birthday data
var lastException: Exception? = null * (i.e. sapling trees for a given height) can be found.
treeFiles.forEach { treefile -> */
try { @VisibleForTesting
internal fun birthdayDirectory(network: ZcashNetwork) =
"saplingtree/${(network.networkName as java.lang.String).toLowerCase(Locale.US)}"
internal fun birthdayHeight(fileName: String) = fileName.split('.').first().toInt()
private fun Array<String>.sortDescending() =
apply { sortByDescending { birthdayHeight(it) } }
/**
* Load the given birthday file from the assets of the given context. When no height is
* specified, we default to the file with the greatest name.
*
* @param context the context from which to load assets.
* @param birthdayHeight the height file to look for among the file names.
*
* @return a WalletBirthday that reflects the contents of the file or an exception when
* parsing fails.
*/
private suspend fun loadBirthdayFromAssets(
context: Context,
network: ZcashNetwork,
birthdayHeight: Int? = null
): WalletBirthday {
twig("loading birthday from assets: $birthdayHeight")
val directory = birthdayDirectory(network)
val treeFiles = getFilteredFileNames(context, directory, birthdayHeight)
twig("found ${treeFiles.size} sapling tree checkpoints: $treeFiles")
return getFirstValidWalletBirthday(context, directory, treeFiles)
}
private suspend fun getFilteredFileNames(
context: Context,
directory: String,
birthdayHeight: Int? = null,
): List<String> {
val unfilteredTreeFiles = listBirthdayDirectoryContents(context, directory)
if (unfilteredTreeFiles.isNullOrEmpty()) {
throw BirthdayException.MissingBirthdayFilesException(directory)
}
val filteredTreeFiles = unfilteredTreeFiles
.sortDescending()
.filter { filename ->
birthdayHeight?.let { birthdayHeight(filename) <= it } ?: true
}
if (filteredTreeFiles.isEmpty()) {
throw BirthdayException.BirthdayFileNotFoundException(
directory,
birthdayHeight
)
}
return filteredTreeFiles
}
/**
* @param treeFiles A list of files, sorted in descending order based on `int` value of the first part of the filename.
*/
@VisibleForTesting
internal suspend fun getFirstValidWalletBirthday(
context: Context,
directory: String,
treeFiles: List<String>
): WalletBirthday {
var lastException: Exception? = null
treeFiles.forEach { treefile ->
try {
return withContext(Dispatchers.IO) {
context.assets.open("$directory/$treefile").use { inputStream -> context.assets.open("$directory/$treefile").use { inputStream ->
InputStreamReader(inputStream).use { inputStreamReader -> InputStreamReader(inputStream).use { inputStreamReader ->
JsonReader(inputStreamReader).use { jsonReader -> JsonReader(inputStreamReader).use { jsonReader ->
return Gson().fromJson(jsonReader, WalletBirthday::class.java) Gson().fromJson(jsonReader, WalletBirthday::class.java)
} }
} }
} }
} catch (t: Throwable) { }
val exception = BirthdayException.MalformattedBirthdayFilesException( } catch (t: Throwable) {
directory, val exception = BirthdayException.MalformattedBirthdayFilesException(
treefile directory,
) treefile
lastException = exception )
lastException = exception
if (IS_FALLBACK_ON_FAILURE) { if (IS_FALLBACK_ON_FAILURE) {
// TODO: If we ever add crash analytics hooks, this would be something to report // TODO: If we ever add crash analytics hooks, this would be something to report
twig("Malformed birthday file $t") twig("Malformed birthday file $t")
} else { } else {
throw exception throw exception
}
} }
} }
throw lastException!!
} }
throw lastException!!
} }
} }