[#269] Convert blocking calls to suspend functions

To quickly get this implemented, some calls in the demo-app have been wrapped in `runBlocking {}`.  This is not ideal and will be addressed in followup issues.
This commit is contained in:
Carter Jernigan 2021-10-21 16:05:02 -04:00
parent 07a00dc376
commit 079229972f
37 changed files with 803 additions and 479 deletions

View File

@ -5,6 +5,8 @@ Upcoming Migrating to Version 1.4.* from 1.3.*
--------------------------------------
Various APIs that have always been considered private have been moved into a new package called `internal`. While this should not be a breaking change, clients that might have relied on these internal classes should stop doing so.
A number of methods have been converted to suspending functions, because they were performing slow or blocking calls (e.g. disk IO) internally. This is a breaking change.
Migrating to Version 1.3.* from 1.2.*
--------------------------------------
The biggest breaking changes in 1.3 that inspired incrementing the minor version number was simplifying down to one "network aware" library rather than two separate libraries, each dedicated to either testnet or mainnet. This greatly simplifies the gradle configuration and has lots of other benefits. Wallets can now set a network with code similar to the following:

View File

@ -141,8 +141,10 @@ class DarksideTestCoordinator(val wallet: TestWallet) {
inner class DarksideTestValidator {
fun validateHasBlock(height: Int) {
assertTrue((synchronizer as SdkSynchronizer).findBlockHashAsHex(height) != null)
assertTrue((synchronizer as SdkSynchronizer).findBlockHash(height)?.size ?: 0 > 0)
runBlocking {
assertTrue((synchronizer as SdkSynchronizer).findBlockHashAsHex(height) != null)
assertTrue((synchronizer as SdkSynchronizer).findBlockHash(height)?.size ?: 0 > 0)
}
}
fun validateLatestHeight(height: Int) = runBlocking<Unit> {
@ -185,7 +187,7 @@ class DarksideTestCoordinator(val wallet: TestWallet) {
}
fun validateBlockHash(height: Int, expectedHash: String) {
val hash = (synchronizer as SdkSynchronizer).findBlockHashAsHex(height)
val hash = runBlocking { (synchronizer as SdkSynchronizer).findBlockHashAsHex(height) }
assertEquals(expectedHash, hash)
}
@ -194,7 +196,7 @@ class DarksideTestCoordinator(val wallet: TestWallet) {
}
fun validateTxCount(count: Int) {
val txCount = (synchronizer as SdkSynchronizer).getTransactionCount()
val txCount = runBlocking { (synchronizer as SdkSynchronizer).getTransactionCount() }
assertEquals("Expected $count transactions but found $txCount instead!", count, txCount)
}

View File

@ -23,6 +23,7 @@ import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.takeWhile
import kotlinx.coroutines.launch
import kotlinx.coroutines.newFixedThreadPoolContext
import kotlinx.coroutines.runBlocking
import java.util.concurrent.TimeoutException
/**
@ -51,19 +52,29 @@ class TestWallet(
val walletScope = CoroutineScope(
SupervisorJob() + newFixedThreadPoolContext(3, this.javaClass.simpleName)
)
// Although runBlocking isn't great, this usage is OK because this is only used within the
// automated tests
private val context = InstrumentationRegistry.getInstrumentation().context
private val seed: ByteArray = Mnemonics.MnemonicCode(seedPhrase).toSeed()
private val shieldedSpendingKey = DerivationTool.deriveSpendingKeys(seed, network = network)[0]
private val transparentSecretKey = DerivationTool.deriveTransparentSecretKey(seed, network = network)
val initializer = Initializer(context) { config ->
config.importWallet(seed, startHeight, network, host, alias = alias)
private val shieldedSpendingKey =
runBlocking { DerivationTool.deriveSpendingKeys(seed, network = network)[0] }
private val transparentSecretKey =
runBlocking { DerivationTool.deriveTransparentSecretKey(seed, network = network) }
val initializer = runBlocking {
Initializer.new(context) { config ->
runBlocking { config.importWallet(seed, startHeight, network, host, alias = alias) }
}
}
val synchronizer: SdkSynchronizer = Synchronizer(initializer) as SdkSynchronizer
val service = (synchronizer.processor.downloader.lightWalletService as LightWalletGrpcService)
val available get() = synchronizer.saplingBalances.value.availableZatoshi
val shieldedAddress = DerivationTool.deriveShieldedAddress(seed, network = network)
val transparentAddress = DerivationTool.deriveTransparentAddress(seed, network = network)
val shieldedAddress =
runBlocking { DerivationTool.deriveShieldedAddress(seed, network = network) }
val transparentAddress =
runBlocking { DerivationTool.deriveTransparentAddress(seed, network = network) }
val birthdayHeight get() = synchronizer.latestBirthdayHeight
val networkName get() = synchronizer.network.networkName
val connectionInfo get() = service.connectionInfo.toString()

View File

@ -52,7 +52,12 @@ class SampleCodeTest {
// ///////////////////////////////////////////////////
// Derive Extended Spending Key
@Test fun deriveSpendingKey() {
val spendingKeys = DerivationTool.deriveSpendingKeys(seed, ZcashNetwork.Mainnet)
val spendingKeys = runBlocking {
DerivationTool.deriveSpendingKeys(
seed,
ZcashNetwork.Mainnet
)
}
assertEquals(1, spendingKeys.size)
log("Spending Key: ${spendingKeys?.get(0)}")
}
@ -140,7 +145,7 @@ class SampleCodeTest {
private val lightwalletdHost: String = ZcashNetwork.Mainnet.defaultHost
private val context = InstrumentationRegistry.getInstrumentation().targetContext
private val synchronizer = Synchronizer(Initializer(context) {})
private val synchronizer = Synchronizer(runBlocking { Initializer.new(context) {} })
@BeforeClass
@JvmStatic

View File

@ -2,6 +2,7 @@ package cash.z.ecc.android.sdk.demoapp.demos.getaddress
import android.os.Bundle
import android.view.LayoutInflater
import androidx.lifecycle.lifecycleScope
import cash.z.ecc.android.bip39.Mnemonics
import cash.z.ecc.android.bip39.toSeed
import cash.z.ecc.android.sdk.demoapp.BaseDemoFragment
@ -11,6 +12,8 @@ import cash.z.ecc.android.sdk.demoapp.util.fromResources
import cash.z.ecc.android.sdk.tool.DerivationTool
import cash.z.ecc.android.sdk.type.UnifiedViewingKey
import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
/**
* Displays the address associated with the seed defined by the default config. To modify the seed
@ -34,14 +37,16 @@ class GetAddressFragment : BaseDemoFragment<FragmentGetAddressBinding>() {
seed = Mnemonics.MnemonicCode(seedPhrase).toSeed()
// the derivation tool can be used for generating keys and addresses
viewingKey = DerivationTool.deriveUnifiedViewingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext())).first()
viewingKey = runBlocking { DerivationTool.deriveUnifiedViewingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext())).first() }
}
private fun displayAddress() {
// a full fledged app would just get the address from the synchronizer
val zaddress = DerivationTool.deriveShieldedAddress(seed, ZcashNetwork.fromResources(requireApplicationContext()))
val taddress = DerivationTool.deriveTransparentAddress(seed, ZcashNetwork.fromResources(requireApplicationContext()))
binding.textInfo.text = "z-addr:\n$zaddress\n\n\nt-addr:\n$taddress"
viewLifecycleOwner.lifecycleScope.launchWhenStarted {
val zaddress = DerivationTool.deriveShieldedAddress(seed, ZcashNetwork.fromResources(requireApplicationContext()))
val taddress = DerivationTool.deriveTransparentAddress(seed, ZcashNetwork.fromResources(requireApplicationContext()))
binding.textInfo.text = "z-addr:\n$zaddress\n\n\nt-addr:\n$taddress"
}
}
// TODO: show an example with the synchronizer
@ -65,10 +70,15 @@ class GetAddressFragment : BaseDemoFragment<FragmentGetAddressBinding>() {
//
override fun onActionButtonClicked() {
copyToClipboard(
DerivationTool.deriveShieldedAddress(viewingKey.extfvk, ZcashNetwork.fromResources(requireApplicationContext())),
"Shielded address copied to clipboard!"
)
viewLifecycleOwner.lifecycleScope.launch {
copyToClipboard(
DerivationTool.deriveShieldedAddress(
viewingKey.extfvk,
ZcashNetwork.fromResources(requireApplicationContext())
),
"Shielded address copied to clipboard!"
)
}
}
override fun inflateBinding(layoutInflater: LayoutInflater): FragmentGetAddressBinding =

View File

@ -17,6 +17,7 @@ import cash.z.ecc.android.sdk.ext.convertZatoshiToZecString
import cash.z.ecc.android.sdk.tool.DerivationTool
import cash.z.ecc.android.sdk.type.WalletBalance
import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.runBlocking
/**
* Displays the available balance && total balance associated with the seed defined by the default config.
@ -43,13 +44,13 @@ class GetBalanceFragment : BaseDemoFragment<FragmentGetBalanceBinding>() {
val seed = Mnemonics.MnemonicCode(seedPhrase).toSeed()
// converting seed into viewingKey
val viewingKey = DerivationTool.deriveUnifiedViewingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext())).first()
val viewingKey = runBlocking { DerivationTool.deriveUnifiedViewingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext())).first() }
// using the ViewingKey to initialize
Initializer(requireApplicationContext()) {
runBlocking {Initializer.new(requireApplicationContext(), null) {
it.setNetwork(ZcashNetwork.fromResources(requireApplicationContext()))
it.importWallet(viewingKey, network = ZcashNetwork.fromResources(requireApplicationContext()))
}.let { initializer ->
}}.let { initializer ->
synchronizer = Synchronizer(initializer)
}
}

View File

@ -2,6 +2,7 @@ package cash.z.ecc.android.sdk.demoapp.demos.getprivatekey
import android.os.Bundle
import android.view.LayoutInflater
import androidx.lifecycle.lifecycleScope
import cash.z.ecc.android.bip39.Mnemonics
import cash.z.ecc.android.bip39.toSeed
import cash.z.ecc.android.sdk.demoapp.BaseDemoFragment
@ -10,6 +11,7 @@ import cash.z.ecc.android.sdk.demoapp.ext.requireApplicationContext
import cash.z.ecc.android.sdk.demoapp.util.fromResources
import cash.z.ecc.android.sdk.tool.DerivationTool
import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.launch
/**
* Displays the viewing key and spending key associated with the seed used during the demo. The
@ -37,13 +39,22 @@ class GetPrivateKeyFragment : BaseDemoFragment<FragmentGetPrivateKeyBinding>() {
private fun displayKeys() {
// derive the keys from the seed:
// demonstrate deriving spending keys for five accounts but only take the first one
val spendingKey = DerivationTool.deriveSpendingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext()), 5).first()
lifecycleScope.launchWhenStarted {
val spendingKey = DerivationTool.deriveSpendingKeys(
seed,
ZcashNetwork.fromResources(requireApplicationContext()),
5
).first()
// derive the key that allows you to view but not spend transactions
val viewingKey = DerivationTool.deriveViewingKey(spendingKey, ZcashNetwork.fromResources(requireApplicationContext()))
// derive the key that allows you to view but not spend transactions
val viewingKey = DerivationTool.deriveViewingKey(
spendingKey,
ZcashNetwork.fromResources(requireApplicationContext())
)
// display the keys in the UI
binding.textInfo.setText("Spending Key:\n$spendingKey\n\nViewing Key:\n$viewingKey")
// display the keys in the UI
binding.textInfo.setText("Spending Key:\n$spendingKey\n\nViewing Key:\n$viewingKey")
}
}
//
@ -65,10 +76,15 @@ class GetPrivateKeyFragment : BaseDemoFragment<FragmentGetPrivateKeyBinding>() {
//
override fun onActionButtonClicked() {
copyToClipboard(
DerivationTool.deriveUnifiedViewingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext())).first().extpub,
"ViewingKey copied to clipboard!"
)
lifecycleScope.launch {
copyToClipboard(
DerivationTool.deriveUnifiedViewingKeys(
seed,
ZcashNetwork.fromResources(requireApplicationContext())
).first().extpub,
"ViewingKey copied to clipboard!"
)
}
}
override fun inflateBinding(layoutInflater: LayoutInflater): FragmentGetPrivateKeyBinding =

View File

@ -19,6 +19,7 @@ import cash.z.ecc.android.sdk.ext.collectWith
import cash.z.ecc.android.sdk.internal.twig
import cash.z.ecc.android.sdk.tool.DerivationTool
import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.runBlocking
/**
* List all transactions related to the given seed, since the given birthday. This begins by
@ -47,11 +48,16 @@ class ListTransactionsFragment : BaseDemoFragment<FragmentListTransactionsBindin
// have the seed stored
val seed = Mnemonics.MnemonicCode(seedPhrase).toSeed()
initializer = Initializer(requireApplicationContext()) {
it.importWallet(seed, network = ZcashNetwork.fromResources(requireApplicationContext()))
initializer = runBlocking {Initializer.new(requireApplicationContext()) {
runBlocking { it.importWallet(seed, network = ZcashNetwork.fromResources(requireApplicationContext())) }
it.setNetwork(ZcashNetwork.fromResources(requireApplicationContext()))
}}
address = runBlocking {
DerivationTool.deriveShieldedAddress(
seed,
ZcashNetwork.fromResources(requireApplicationContext())
)
}
address = DerivationTool.deriveShieldedAddress(seed, ZcashNetwork.fromResources(requireApplicationContext()))
synchronizer = Synchronizer(initializer)
}

View File

@ -25,6 +25,7 @@ import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
/**
@ -60,10 +61,10 @@ class ListUtxosFragment : BaseDemoFragment<FragmentListUtxosBinding>() {
// Use a BIP-39 library to convert a seed phrase into a byte array. Most wallets already
// have the seed stored
seed = Mnemonics.MnemonicCode(sharedViewModel.seedPhrase.value).toSeed()
initializer = Initializer(requireApplicationContext()) {
it.importWallet(seed, network = ZcashNetwork.fromResources(requireApplicationContext()))
initializer = runBlocking {Initializer.new(requireApplicationContext()) {
runBlocking { it.importWallet(seed, network = ZcashNetwork.fromResources(requireApplicationContext())) }
it.alias = "Demo_Utxos"
}
}}
synchronizer = Synchronizer(initializer)
}
@ -102,7 +103,7 @@ class ListUtxosFragment : BaseDemoFragment<FragmentListUtxosBinding>() {
txids?.map {
it.data.apply {
try {
initializer.rustBackend.decryptAndStoreTransaction(toByteArray())
runBlocking { initializer.rustBackend.decryptAndStoreTransaction(toByteArray()) }
} catch (t: Throwable) {
twig("failed to decrypt and store transaction due to: $t")
}
@ -154,7 +155,9 @@ class ListUtxosFragment : BaseDemoFragment<FragmentListUtxosBinding>() {
super.onResume()
resetInBackground()
val seed = Mnemonics.MnemonicCode(sharedViewModel.seedPhrase.value).toSeed()
binding.inputAddress.setText(DerivationTool.deriveTransparentAddress(seed, ZcashNetwork.fromResources(requireApplicationContext())))
viewLifecycleOwner.lifecycleScope.launchWhenStarted {
binding.inputAddress.setText(DerivationTool.deriveTransparentAddress(seed, ZcashNetwork.fromResources(requireApplicationContext())))
}
}
var initialCount: Int = 0

View File

@ -32,6 +32,7 @@ import cash.z.ecc.android.sdk.internal.twig
import cash.z.ecc.android.sdk.tool.DerivationTool
import cash.z.ecc.android.sdk.type.WalletBalance
import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.runBlocking
/**
* Demonstrates sending funds to an address. This is the most complex example that puts all of the
@ -63,13 +64,13 @@ class SendFragment : BaseDemoFragment<FragmentSendBinding>() {
// have the seed stored
val seed = Mnemonics.MnemonicCode(seedPhrase).toSeed()
Initializer(requireApplicationContext()) {
it.importWallet(seed, network = ZcashNetwork.fromResources(requireApplicationContext()))
runBlocking { Initializer.new(requireApplicationContext()) {
runBlocking { it.importWallet(seed, network = ZcashNetwork.fromResources(requireApplicationContext())) }
it.setNetwork(ZcashNetwork.fromResources(requireApplicationContext()))
}.let { initializer ->
}}.let { initializer ->
synchronizer = Synchronizer(initializer)
}
spendingKey = DerivationTool.deriveSpendingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext())).first()
spendingKey = runBlocking { DerivationTool.deriveSpendingKeys(seed, ZcashNetwork.fromResources(requireApplicationContext())).first() }
}
//

View File

@ -5,6 +5,7 @@ import androidx.test.core.app.ApplicationProvider
import androidx.test.filters.SmallTest
import cash.z.ecc.android.sdk.tool.WalletBirthdayTool
import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.runBlocking
import org.json.JSONObject
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
@ -92,9 +93,9 @@ class AssetTest {
private data class JsonFile(val jsonObject: JSONObject, val filename: String)
companion object {
fun listAssets(network: ZcashNetwork) = WalletBirthdayTool.listBirthdayDirectoryContents(
fun listAssets(network: ZcashNetwork) = runBlocking { WalletBirthdayTool.listBirthdayDirectoryContents(
ApplicationProvider.getApplicationContext<Context>(),
WalletBirthdayTool.birthdayDirectory(network)
)
WalletBirthdayTool.birthdayDirectory(network))
}
}
}

View File

@ -3,13 +3,14 @@ package cash.z.ecc.android.sdk.ext
import cash.z.ecc.android.sdk.Initializer
import cash.z.ecc.android.sdk.type.ZcashNetwork
import cash.z.ecc.android.sdk.util.SimpleMnemonics
import kotlinx.coroutines.runBlocking
import okhttp3.OkHttpClient
import okhttp3.Request
import org.json.JSONObject
import ru.gildor.coroutines.okhttp.await
fun Initializer.Config.seedPhrase(seedPhrase: String, network: ZcashNetwork) {
setSeed(SimpleMnemonics().toSeed(seedPhrase.toCharArray()), network)
runBlocking { setSeed(SimpleMnemonics().toSeed(seedPhrase.toCharArray()), network) }
}
object BlockExplorer {

View File

@ -46,7 +46,13 @@ class TestnetIntegrationTest : ScopedTest() {
@Test
fun testLoadBirthday() {
val (height, hash, time, tree) = WalletBirthdayTool.loadNearest(context, synchronizer.network, saplingActivation + 1)
val (height, hash, time, tree) = runBlocking {
WalletBirthdayTool.loadNearest(
context,
synchronizer.network,
saplingActivation + 1
)
}
assertEquals(saplingActivation, height)
}
@ -118,9 +124,11 @@ class TestnetIntegrationTest : ScopedTest() {
val toAddress = "zs1vp7kvlqr4n9gpehztr76lcn6skkss9p8keqs3nv8avkdtjrcctrvmk9a7u494kluv756jeee5k0"
private val context = InstrumentationRegistry.getInstrumentation().context
private val initializer = Initializer(context) { config ->
config.setNetwork(ZcashNetwork.Testnet, host)
config.importWallet(seed, birthdayHeight, ZcashNetwork.Testnet)
private val initializer = runBlocking {
Initializer.new(context) { config ->
config.setNetwork(ZcashNetwork.Testnet, host)
runBlocking { config.importWallet(seed, birthdayHeight, ZcashNetwork.Testnet) }
}
}
private lateinit var synchronizer: Synchronizer

View File

@ -3,6 +3,7 @@ package cash.z.ecc.android.sdk.jni
import cash.z.ecc.android.sdk.annotation.MaintainedTest
import cash.z.ecc.android.sdk.annotation.TestPurpose
import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.runBlocking
import org.junit.Assert.assertEquals
import org.junit.Test
import org.junit.runner.RunWith
@ -43,8 +44,8 @@ class BranchIdTest(
// is an abnormal use of the SDK because this really should run at the rust level
// However, due to quirks on certain devices, we created this test at the Android level,
// as a sanity check
val testnetBackend = RustBackend.init("", "", "", ZcashNetwork.Testnet)
val mainnetBackend = RustBackend.init("", "", "", ZcashNetwork.Mainnet)
val testnetBackend = runBlocking { RustBackend.init("", "", "", ZcashNetwork.Testnet) }
val mainnetBackend = runBlocking { RustBackend.init("", "", "", ZcashNetwork.Mainnet) }
return listOf(
// Mainnet Cases
arrayOf("Sapling", 419_200, 1991772603L, "76b809bb", mainnetBackend),

View File

@ -9,6 +9,7 @@ import cash.z.ecc.android.sdk.internal.TroubleshootingTwig
import cash.z.ecc.android.sdk.internal.Twig
import cash.z.ecc.android.sdk.tool.DerivationTool
import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.runBlocking
import org.junit.Assert.assertEquals
import org.junit.BeforeClass
import org.junit.Test
@ -21,23 +22,23 @@ import org.junit.runners.Parameterized
class TransparentTest(val expected: Expected, val network: ZcashNetwork) {
@Test
fun deriveTransparentSecretKeyTest() {
fun deriveTransparentSecretKeyTest() = runBlocking {
assertEquals(expected.tskCompressed, DerivationTool.deriveTransparentSecretKey(SEED, network = network))
}
@Test
fun deriveTransparentAddressTest() {
fun deriveTransparentAddressTest() = runBlocking {
assertEquals(expected.tAddr, DerivationTool.deriveTransparentAddress(SEED, network = network))
}
@Test
fun deriveTransparentAddressFromSecretKeyTest() {
fun deriveTransparentAddressFromSecretKeyTest() = runBlocking {
val pk = DerivationTool.deriveTransparentSecretKey(SEED, network = network)
assertEquals(expected.tAddr, DerivationTool.deriveTransparentAddressFromPrivateKey(pk, network = network))
}
@Test
fun deriveUnifiedViewingKeysFromSeedTest() {
fun deriveUnifiedViewingKeysFromSeedTest() = runBlocking {
val uvks = DerivationTool.deriveUnifiedViewingKeys(SEED, network = network)
assertEquals(1, uvks.size)
val uvk = uvks.first()

View File

@ -4,7 +4,8 @@ import android.content.Context
import androidx.test.core.app.ApplicationProvider
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.filters.SmallTest
import cash.z.ecc.android.sdk.tool.WalletBirthdayTool.Companion.IS_FALLBACK_ON_FAILURE
import cash.z.ecc.android.sdk.tool.WalletBirthdayTool.IS_FALLBACK_ON_FAILURE
import kotlinx.coroutines.runBlocking
import org.junit.Assert.assertEquals
import org.junit.Test
import org.junit.runner.RunWith
@ -25,11 +26,13 @@ class WalletBirthdayToolTest {
val directory = "saplingtree/goodnet"
val context = ApplicationProvider.getApplicationContext<Context>()
val birthday = WalletBirthdayTool.getFirstValidWalletBirthday(
context,
directory,
listOf("1300000.json", "1290000.json")
)
val birthday = runBlocking {
WalletBirthdayTool.getFirstValidWalletBirthday(
context,
directory,
listOf("1300000.json", "1290000.json")
)
}
assertEquals(1300000, birthday.height)
}
@ -42,11 +45,13 @@ class WalletBirthdayToolTest {
val directory = "saplingtree/badnet"
val context = ApplicationProvider.getApplicationContext<Context>()
val birthday = WalletBirthdayTool.getFirstValidWalletBirthday(
context,
directory,
listOf("1300000.json", "1290000.json")
)
val birthday = runBlocking {
WalletBirthdayTool.getFirstValidWalletBirthday(
context,
directory,
listOf("1300000.json", "1290000.json")
)
}
assertEquals(1290000, birthday.height)
}
}

View File

@ -5,6 +5,7 @@ import cash.z.ecc.android.sdk.Initializer
import cash.z.ecc.android.sdk.Synchronizer
import cash.z.ecc.android.sdk.internal.TroubleshootingTwig
import cash.z.ecc.android.sdk.internal.Twig
import cash.z.ecc.android.sdk.internal.ext.deleteSuspend
import cash.z.ecc.android.sdk.internal.twig
import cash.z.ecc.android.sdk.tool.WalletBirthdayTool
import cash.z.ecc.android.sdk.type.WalletBirthday
@ -52,7 +53,7 @@ class BalancePrinterUtil {
fun setup() {
Twig.plant(TroubleshootingTwig())
cacheBlocks()
birthday = WalletBirthdayTool.loadNearest(context, network, birthdayHeight)
birthday = runBlocking { WalletBirthdayTool.loadNearest(context, network, birthdayHeight) }
}
private fun cacheBlocks() = runBlocking {
@ -66,8 +67,8 @@ class BalancePrinterUtil {
// assertEquals(-1, error)
}
private fun deleteDb(dbName: String) {
context.getDatabasePath(dbName).absoluteFile.delete()
private suspend fun deleteDb(dbName: String) {
context.getDatabasePath(dbName).absoluteFile.deleteSuspend()
}
@Test
@ -79,8 +80,8 @@ class BalancePrinterUtil {
mnemonics.toSeed(seedPhrase.toCharArray())
}.collect { seed ->
// TODO: clear the dataDb but leave the cacheDb
val initializer = Initializer(context) { config ->
config.importWallet(seed, birthdayHeight, network)
val initializer = Initializer.new(context) { config ->
runBlocking { config.importWallet(seed, birthdayHeight, network) }
config.setNetwork(network)
config.alias = alias
}

View File

@ -64,7 +64,13 @@ class DataDbScannerUtil {
@Test
@Ignore("This test is broken")
fun scanExistingDb() {
synchronizer = Synchronizer(Initializer(context) { it.setBirthdayHeight(birthdayHeight) })
synchronizer = Synchronizer(runBlocking {
Initializer.new(context) {
it.setBirthdayHeight(
birthdayHeight
)
}
})
println("sync!")
synchronizer.start()

View File

@ -23,6 +23,7 @@ import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.takeWhile
import kotlinx.coroutines.launch
import kotlinx.coroutines.newFixedThreadPoolContext
import kotlinx.coroutines.runBlocking
import java.util.concurrent.TimeoutException
/**
@ -51,19 +52,29 @@ class TestWallet(
val walletScope = CoroutineScope(
SupervisorJob() + newFixedThreadPoolContext(3, this.javaClass.simpleName)
)
// Although runBlocking isn't great, this usage is OK because this is only used within the
// automated tests
private val context = InstrumentationRegistry.getInstrumentation().context
private val seed: ByteArray = Mnemonics.MnemonicCode(seedPhrase).toSeed()
private val shieldedSpendingKey = DerivationTool.deriveSpendingKeys(seed, network = network)[0]
private val transparentSecretKey = DerivationTool.deriveTransparentSecretKey(seed, network = network)
val initializer = Initializer(context) { config ->
config.importWallet(seed, startHeight, network, host, alias = alias)
private val shieldedSpendingKey =
runBlocking { DerivationTool.deriveSpendingKeys(seed, network = network)[0] }
private val transparentSecretKey =
runBlocking { DerivationTool.deriveTransparentSecretKey(seed, network = network) }
val initializer = runBlocking {
Initializer.new(context) { config ->
runBlocking { config.importWallet(seed, startHeight, network, host, alias = alias) }
}
}
val synchronizer: SdkSynchronizer = Synchronizer(initializer) as SdkSynchronizer
val service = (synchronizer.processor.downloader.lightWalletService as LightWalletGrpcService)
val available get() = synchronizer.saplingBalances.value.availableZatoshi
val shieldedAddress = DerivationTool.deriveShieldedAddress(seed, network = network)
val transparentAddress = DerivationTool.deriveTransparentAddress(seed, network = network)
val shieldedAddress =
runBlocking { DerivationTool.deriveShieldedAddress(seed, network = network) }
val transparentAddress =
runBlocking { DerivationTool.deriveTransparentAddress(seed, network = network) }
val birthdayHeight get() = synchronizer.latestBirthdayHeight
val networkName get() = synchronizer.network.networkName
val connectionInfo get() = service.connectionInfo.toString()

View File

@ -4,95 +4,37 @@ import android.content.Context
import cash.z.ecc.android.sdk.exception.InitializerException
import cash.z.ecc.android.sdk.ext.ZcashSdk
import cash.z.ecc.android.sdk.internal.twig
import cash.z.ecc.android.sdk.internal.SdkDispatchers
import cash.z.ecc.android.sdk.internal.ext.getCacheDirSuspend
import cash.z.ecc.android.sdk.internal.ext.getDatabasePathSuspend
import cash.z.ecc.android.sdk.jni.RustBackend
import cash.z.ecc.android.sdk.tool.DerivationTool
import cash.z.ecc.android.sdk.tool.WalletBirthdayTool
import cash.z.ecc.android.sdk.type.UnifiedViewingKey
import cash.z.ecc.android.sdk.type.WalletBirthday
import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.io.File
/**
* Simplified Initializer focused on starting from a ViewingKey.
*/
class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Throwable?) -> Boolean)? = null, config: Config) {
val context = appContext.applicationContext
val rustBackend: RustBackend
val network: ZcashNetwork
val alias: String
val host: String
val port: Int
val viewingKeys: List<UnifiedViewingKey>
val overwriteVks: Boolean
class Initializer private constructor(
val context: Context,
val rustBackend: RustBackend,
val network: ZcashNetwork,
val alias: String,
val host: String,
val port: Int,
val viewingKeys: List<UnifiedViewingKey>,
val overwriteVks: Boolean,
val birthday: WalletBirthday
) {
/**
* A callback to invoke whenever an uncaught error is encountered. By definition, the return
* value of the function is ignored because this error is unrecoverable. The only reason the
* function has a return value is so that all error handlers work with the same signature which
* allows one function to handle all errors in simple apps.
*/
var onCriticalErrorHandler: ((Throwable?) -> Boolean)? = onCriticalErrorHandler
suspend fun erase() = erase(context, network, alias)
init {
try {
config.validate()
network = config.network
val heightToUse = config.birthdayHeight
?: (if (config.defaultToOldestHeight == true) network.saplingActivationHeight else null)
val loadedBirthday = WalletBirthdayTool.loadNearest(context, network, heightToUse)
birthday = loadedBirthday
viewingKeys = config.viewingKeys
overwriteVks = config.overwriteVks
alias = config.alias
host = config.host
port = config.port
rustBackend = initRustBackend(network, birthday)
} catch (t: Throwable) {
onCriticalError(t)
throw t
}
}
constructor(appContext: Context, config: Config) : this(appContext, null, config)
constructor(appContext: Context, onCriticalErrorHandler: ((Throwable?) -> Boolean)? = null, block: (Config) -> Unit) : this(appContext, onCriticalErrorHandler, Config(block))
fun erase() = erase(context, network, alias)
private fun initRustBackend(network: ZcashNetwork, birthday: WalletBirthday): RustBackend {
return RustBackend.init(
cacheDbPath(context, network, alias),
dataDbPath(context, network, alias),
"${context.cacheDir.absolutePath}/params",
network,
birthday.height
)
}
private fun onCriticalError(error: Throwable) {
twig("********")
twig("******** INITIALIZER ERROR: $error")
if (error.cause != null) twig("******** caused by ${error.cause}")
if (error.cause?.cause != null) twig("******** caused by ${error.cause?.cause}")
twig("********")
twig(error)
if (onCriticalErrorHandler == null) {
twig(
"WARNING: a critical error occurred on the Initializer but no callback is " +
"registered to be notified of critical errors! THIS IS PROBABLY A MISTAKE. To " +
"respond to these errors (perhaps to update the UI or alert the user) set " +
"initializer.onCriticalErrorHandler to a non-null value or use the secondary " +
"constructor: Initializer(context, handler) { ... }. Note that the synchronizer " +
"and initializer BOTH have error handlers and since the initializer exists " +
"before the synchronizer, it needs its error handler set separately."
)
}
onCriticalErrorHandler?.invoke(error)
}
class Config private constructor (
class Config private constructor(
val viewingKeys: MutableList<UnifiedViewingKey> = mutableListOf(),
var alias: String = ZcashSdk.DEFAULT_ALIAS,
) {
@ -177,7 +119,10 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
* is not currently well supported. Consider it an alpha-preview feature that might work but
* probably has serious bugs.
*/
fun setViewingKeys(vararg unifiedViewingKeys: UnifiedViewingKey, overwrite: Boolean = false): Config = apply {
fun setViewingKeys(
vararg unifiedViewingKeys: UnifiedViewingKey,
overwrite: Boolean = false
): Config = apply {
overwriteVks = overwrite
viewingKeys.apply {
clear()
@ -225,7 +170,7 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
/**
* Import a wallet using the first viewing key derived from the given seed.
*/
fun importWallet(
suspend fun importWallet(
seed: ByteArray,
birthdayHeight: Int? = null,
network: ZcashNetwork,
@ -262,7 +207,7 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
/**
* Create a new wallet using the first viewing key derived from the given seed.
*/
fun newWallet(
suspend fun newWallet(
seed: ByteArray,
network: ZcashNetwork,
host: String = network.defaultHost,
@ -296,9 +241,20 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
* Convenience method for setting thew viewingKeys from a given seed. This is the same as
* calling `setViewingKeys` with the keys that match this seed.
*/
fun setSeed(seed: ByteArray, network: ZcashNetwork, numberOfAccounts: Int = 1): Config = apply {
setViewingKeys(*DerivationTool.deriveUnifiedViewingKeys(seed, network, numberOfAccounts))
}
suspend fun setSeed(
seed: ByteArray,
network: ZcashNetwork,
numberOfAccounts: Int = 1
): Config =
apply {
setViewingKeys(
*DerivationTool.deriveUnifiedViewingKeys(
seed,
network,
numberOfAccounts
)
)
}
/**
* Sets the network from a network id, throwing an exception if the id is not recognized.
@ -338,16 +294,89 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
private fun validateViewingKeys() {
require(viewingKeys.isNotEmpty()) {
"Unified Viewing keys are required. Ensure that the unified viewing keys or seed" +
" have been set on this Initializer."
" have been set on this Initializer."
}
viewingKeys.forEach {
DerivationTool.validateUnifiedViewingKey(it)
}
}
}
companion object : SdkSynchronizer.Erasable {
suspend fun new(appContext: Context, config: Config) = new(appContext, null, config)
suspend fun new(
appContext: Context,
onCriticalErrorHandler: ((Throwable?) -> Boolean)? = null,
block: (Config) -> Unit
) = new(appContext, onCriticalErrorHandler, Config(block))
suspend fun new(
context: Context,
onCriticalErrorHandler: ((Throwable?) -> Boolean)?,
config: Config
): Initializer {
config.validate()
val heightToUse = config.birthdayHeight
?: (if (config.defaultToOldestHeight == true) config.network.saplingActivationHeight else null)
val loadedBirthday =
WalletBirthdayTool.loadNearest(context, config.network, heightToUse)
val rustBackend = initRustBackend(context, config.network, config.alias, loadedBirthday)
return Initializer(
context.applicationContext,
rustBackend,
config.network,
config.alias,
config.host,
config.port,
config.viewingKeys,
config.overwriteVks,
loadedBirthday
)
}
private fun onCriticalError(onCriticalErrorHandler: ((Throwable?) -> Boolean)?, error: Throwable) {
twig("********")
twig("******** INITIALIZER ERROR: $error")
if (error.cause != null) twig("******** caused by ${error.cause}")
if (error.cause?.cause != null) twig("******** caused by ${error.cause?.cause}")
twig("********")
twig(error)
if (onCriticalErrorHandler == null) {
twig(
"WARNING: a critical error occurred on the Initializer but no callback is " +
"registered to be notified of critical errors! THIS IS PROBABLY A MISTAKE. To " +
"respond to these errors (perhaps to update the UI or alert the user) set " +
"initializer.onCriticalErrorHandler to a non-null value or use the secondary " +
"constructor: Initializer(context, handler) { ... }. Note that the synchronizer " +
"and initializer BOTH have error handlers and since the initializer exists " +
"before the synchronizer, it needs its error handler set separately."
)
}
onCriticalErrorHandler?.invoke(error)
}
private suspend fun initRustBackend(
context: Context,
network: ZcashNetwork,
alias: String,
birthday: WalletBirthday
): RustBackend {
return RustBackend.init(
cacheDbPath(context, network, alias),
dataDbPath(context, network, alias),
File(context.getCacheDirSuspend(), "params").absolutePath,
network,
birthday.height
)
}
/**
* Delete the databases associated with this wallet. This removes all compact blocks and
* data derived from those blocks. For most wallets, this should not result in a loss of
@ -362,7 +391,11 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
* @return true when one of the associated files was found. False most likely indicates
* that the wrong alias was provided.
*/
override fun erase(appContext: Context, network: ZcashNetwork, alias: String): Boolean {
override suspend fun erase(
appContext: Context,
network: ZcashNetwork,
alias: String
): Boolean {
val cacheDeleted = deleteDb(cacheDbPath(appContext, network, alias))
val dataDeleted = deleteDb(dataDbPath(appContext, network, alias))
return dataDeleted || cacheDeleted
@ -379,7 +412,11 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
* @param network the network associated with the data in the cache database.
* @param alias the alias to convert into a database path
*/
internal fun cacheDbPath(appContext: Context, network: ZcashNetwork, alias: String): String =
private suspend fun cacheDbPath(
appContext: Context,
network: ZcashNetwork,
alias: String
): String =
aliasToPath(appContext, network, alias, ZcashSdk.DB_CACHE_NAME)
/**
@ -388,12 +425,21 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
* * @param network the network associated with the data in the database.
* @param alias the alias to convert into a database path
*/
internal fun dataDbPath(appContext: Context, network: ZcashNetwork, alias: String): String =
private suspend fun dataDbPath(
appContext: Context,
network: ZcashNetwork,
alias: String
): String =
aliasToPath(appContext, network, alias, ZcashSdk.DB_DATA_NAME)
private fun aliasToPath(appContext: Context, network: ZcashNetwork, alias: String, dbFileName: String): String {
private suspend fun aliasToPath(
appContext: Context,
network: ZcashNetwork,
alias: String,
dbFileName: String
): String {
val parentDir: String =
appContext.getDatabasePath("unused.db").parentFile?.absolutePath
appContext.getDatabasePathSuspend("unused.db").parentFile?.absolutePath
?: throw InitializerException.DatabasePathException
val prefix = if (alias.endsWith('_')) alias else "${alias}_"
return File(parentDir, "$prefix${network.networkName}_$dbFileName").absolutePath
@ -405,9 +451,10 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
* @param filePath the path of the db to erase.
* @return true when a file exists at the given path and was deleted.
*/
private fun deleteDb(filePath: String): Boolean {
private suspend fun deleteDb(filePath: String): Boolean {
// just try the journal file. Doesn't matter if it's not there.
delete("$filePath-journal")
return delete(filePath)
}
@ -417,14 +464,16 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
* @param filePath the path of the file to erase.
* @return true when a file exists at the given path and was deleted.
*/
private fun delete(filePath: String): Boolean {
private suspend fun delete(filePath: String): Boolean {
return File(filePath).let {
if (it.exists()) {
twig("Deleting ${it.name}!")
it.delete()
true
} else {
false
withContext(SdkDispatchers.IO) {
if (it.exists()) {
twig("Deleting ${it.name}!")
it.delete()
true
} else {
false
}
}
}
}
@ -445,9 +494,9 @@ class Initializer constructor(appContext: Context, onCriticalErrorHandler: ((Thr
internal fun validateAlias(alias: String) {
require(
alias.length in 1..99 && alias[0].isLetter() &&
alias.all { it.isLetterOrDigit() || it == '_' }
alias.all { it.isLetterOrDigit() || it == '_' }
) {
"ERROR: Invalid alias ($alias). For security, the alias must be shorter than 100 " +
"characters and only contain letters, digits or underscores and start with a letter."
"characters and only contain letters, digits or underscores and start with a letter."
}
}

View File

@ -58,7 +58,6 @@ import io.grpc.ManagedChannel
import kotlinx.coroutines.CoroutineExceptionHandler
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Dispatchers.IO
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.FlowPreview
import kotlinx.coroutines.Job
@ -247,7 +246,7 @@ class SdkSynchronizer internal constructor(
override val latestBirthdayHeight: Int get() = processor.birthdayHeight
override fun prepare(): Synchronizer = apply {
override suspend fun prepare(): Synchronizer = apply {
storage.prepare()
}
@ -336,15 +335,15 @@ class SdkSynchronizer internal constructor(
// TODO: turn this section into the data access API. For now, just aggregate all the things that we want to do with the underlying data
fun findBlockHash(height: Int): ByteArray? {
suspend fun findBlockHash(height: Int): ByteArray? {
return (storage as? PagedTransactionRepository)?.findBlockHash(height)
}
fun findBlockHashAsHex(height: Int): String? {
suspend fun findBlockHashAsHex(height: Int): String? {
return findBlockHash(height)?.toHexReversed()
}
fun getTransactionCount(): Int {
suspend fun getTransactionCount(): Int {
return (storage as? PagedTransactionRepository)?.getTransactionCount() ?: 0
}
@ -530,7 +529,7 @@ class SdkSynchronizer internal constructor(
}
}
private suspend fun refreshPendingTransactions() = withContext(IO) {
private suspend fun refreshPendingTransactions() = withContext(Dispatchers.IO) {
twig("[cleanup] beginning to refresh and clean up pending transactions")
// TODO: this would be the place to clear out any stale pending transactions. Remove filter
// logic and then delete any pending transaction with sufficient confirmations (all in one
@ -737,7 +736,7 @@ class SdkSynchronizer internal constructor(
*
* @return true when content was found for the given alias. False otherwise.
*/
fun erase(appContext: Context, network: ZcashNetwork, alias: String = ZcashSdk.DEFAULT_ALIAS): Boolean
suspend fun erase(appContext: Context, network: ZcashNetwork, alias: String = ZcashSdk.DEFAULT_ALIAS): Boolean
}
}

View File

@ -35,7 +35,7 @@ interface Synchronizer {
* where setup and maintenance can occur for various Synchronizers. One that uses a database
* would take this opportunity to do data migrations or key migrations.
*/
fun prepare(): Synchronizer
suspend fun prepare(): Synchronizer
/**
* Starts this synchronizer within the given scope.

View File

@ -526,7 +526,7 @@ class CompactBlockProcessor(
* @return [ERROR_CODE_NONE] when there is no problem. Otherwise, return the lowest height where an error was
* found. In other words, validation starts at the back of the chain and works toward the tip.
*/
private fun validateNewBlocks(range: IntRange?): Int {
private suspend fun validateNewBlocks(range: IntRange?): Int {
if (range?.isEmpty() != false) {
twig("no blocks to validate: $range")
return ERROR_CODE_NONE

View File

@ -0,0 +1,14 @@
package cash.z.ecc.android.sdk.internal
import kotlinx.coroutines.asCoroutineDispatcher
import java.util.concurrent.Executors
internal object SdkDispatchers {
/*
* Based on internal discussion, keep the SDK internals confined to a single IO thread.
*
* We don't expect things to break, but we don't have the WAL enabled for SQLite so this
* is a simple solution.
*/
val IO = Executors.newSingleThreadExecutor().asCoroutineDispatcher()
}

View File

@ -6,8 +6,9 @@ import androidx.room.RoomDatabase
import cash.z.ecc.android.sdk.internal.db.CompactBlockDao
import cash.z.ecc.android.sdk.internal.db.CompactBlockDb
import cash.z.ecc.android.sdk.db.entity.CompactBlockEntity
import cash.z.ecc.android.sdk.internal.SdkDispatchers
import cash.z.wallet.sdk.rpc.CompactFormats
import kotlinx.coroutines.Dispatchers.IO
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
/**
@ -38,7 +39,7 @@ class CompactBlockDbStore(
.build()
}
override suspend fun getLatestHeight(): Int = withContext(IO) {
override suspend fun getLatestHeight(): Int = withContext(SdkDispatchers.IO) {
Math.max(0, cacheDao.latestBlockHeight())
}
@ -46,15 +47,17 @@ class CompactBlockDbStore(
return cacheDao.findCompactBlock(height)?.let { CompactFormats.CompactBlock.parseFrom(it) }
}
override suspend fun write(result: List<CompactFormats.CompactBlock>) = withContext(IO) {
override suspend fun write(result: List<CompactFormats.CompactBlock>) = withContext(SdkDispatchers.IO) {
cacheDao.insert(result.map { CompactBlockEntity(it.height.toInt(), it.toByteArray()) })
}
override suspend fun rewindTo(height: Int) = withContext(IO) {
override suspend fun rewindTo(height: Int) = withContext(SdkDispatchers.IO) {
cacheDao.rewindTo(height)
}
override fun close() {
cacheDb.close()
override suspend fun close() {
withContext(SdkDispatchers.IO) {
cacheDb.close()
}
}
}

View File

@ -7,6 +7,7 @@ import cash.z.ecc.android.sdk.internal.service.LightWalletService
import cash.z.wallet.sdk.rpc.Service
import io.grpc.StatusRuntimeException
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Dispatchers.IO
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
@ -122,8 +123,10 @@ open class CompactBlockDownloader private constructor(val compactBlockStore: Com
/**
* Stop this downloader and cleanup any resources being used.
*/
fun stop() {
lightWalletService.shutdown()
suspend fun stop() {
withContext(Dispatchers.IO) {
lightWalletService.shutdown()
}
compactBlockStore.close()
}

View File

@ -37,5 +37,5 @@ interface CompactBlockStore {
/**
* Close any connections to the block store.
*/
fun close()
suspend fun close()
}

View File

@ -0,0 +1,11 @@
package cash.z.ecc.android.sdk.internal.ext
import android.content.Context
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
suspend fun Context.getDatabasePathSuspend(fileName: String) =
withContext(Dispatchers.IO) { getDatabasePath(fileName) }
suspend fun Context.getCacheDirSuspend() =
withContext(Dispatchers.IO) { cacheDir }

View File

@ -0,0 +1,7 @@
package cash.z.ecc.android.sdk.internal.ext
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.io.File
suspend fun File.deleteSuspend() = withContext(Dispatchers.IO) { delete() }

View File

@ -16,10 +16,12 @@ import cash.z.ecc.android.sdk.internal.ext.android.toFlowPagedList
import cash.z.ecc.android.sdk.internal.ext.android.toRefreshable
import cash.z.ecc.android.sdk.internal.ext.tryWarn
import cash.z.ecc.android.sdk.internal.twig
import cash.z.ecc.android.sdk.internal.SdkDispatchers
import cash.z.ecc.android.sdk.jni.RustBackend
import cash.z.ecc.android.sdk.type.UnifiedAddressAccount
import cash.z.ecc.android.sdk.type.UnifiedViewingKey
import cash.z.ecc.android.sdk.type.WalletBirthday
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Dispatchers.IO
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.emitAll
@ -95,7 +97,7 @@ class PagedTransactionRepository(
override suspend fun getAccountCount(): Int = lazy.accounts.count()
override fun prepare() {
override suspend fun prepare() {
if (lazy.isPrepared.get()) {
twig("Warning: skipped the preparation step because we're already prepared!")
} else {
@ -112,7 +114,7 @@ class PagedTransactionRepository(
* side because Rust was intended to own the "dataDb" and Kotlin just reads from it. Since then,
* it has been more clear that Kotlin should own the data and just let Rust use it.
*/
private fun initMissingDatabases() {
private suspend fun initMissingDatabases() {
maybeCreateDataDb()
maybeInitBlocksTable(birthday)
maybeInitAccountsTable(viewingKeys)
@ -121,7 +123,7 @@ class PagedTransactionRepository(
/**
* Create the dataDb and its table, if it doesn't exist.
*/
private fun maybeCreateDataDb() {
private suspend fun maybeCreateDataDb() {
tryWarn("Warning: did not create dataDb. It probably already exists.") {
rustBackend.initDataDb()
twig("Initialized wallet for first run file: ${rustBackend.pathDataDb}")
@ -131,7 +133,7 @@ class PagedTransactionRepository(
/**
* Initialize the blocks table with the given birthday, if needed.
*/
private fun maybeInitBlocksTable(birthday: WalletBirthday) {
private suspend fun maybeInitBlocksTable(birthday: WalletBirthday) {
// TODO: consider converting these to typed exceptions in the welding layer
tryWarn(
"Warning: did not initialize the blocks table. It probably was already initialized.",
@ -151,7 +153,7 @@ class PagedTransactionRepository(
/**
* Initialize the accounts table with the given viewing keys.
*/
private fun maybeInitAccountsTable(viewingKeys: List<UnifiedViewingKey>) {
private suspend fun maybeInitAccountsTable(viewingKeys: List<UnifiedViewingKey>) {
// TODO: consider converting these to typed exceptions in the welding layer
tryWarn(
"Warning: did not initialize the accounts table. It probably was already initialized.",
@ -181,7 +183,7 @@ class PagedTransactionRepository(
}
}
private fun applyKeyMigrations() {
private suspend fun applyKeyMigrations() {
if (overwriteVks) {
twig("applying key migrations . . .")
maybeInitAccountsTable(viewingKeys)
@ -191,8 +193,10 @@ class PagedTransactionRepository(
/**
* Close the underlying database.
*/
fun close() {
lazy.db?.close()
suspend fun close() {
withContext(Dispatchers.IO) {
lazy.db?.close()
}
}
// TODO: begin converting these into Data Access API. For now, just collect the desired operations and iterate/refactor, later

View File

@ -87,7 +87,7 @@ interface TransactionRepository {
suspend fun getAccountCount(): Int
fun prepare()
suspend fun prepare()
//
// Transactions

View File

@ -8,7 +8,8 @@ import cash.z.ecc.android.sdk.internal.twigTask
import cash.z.ecc.android.sdk.jni.RustBackend
import cash.z.ecc.android.sdk.jni.RustBackendWelding
import cash.z.ecc.android.sdk.internal.SaplingParamTool
import kotlinx.coroutines.Dispatchers.IO
import cash.z.ecc.android.sdk.internal.SdkDispatchers
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
/**
@ -44,7 +45,7 @@ class WalletTransactionEncoder(
toAddress: String,
memo: ByteArray?,
fromAccountIndex: Int
): EncodedTransaction = withContext(IO) {
): EncodedTransaction = withContext(SdkDispatchers.IO) {
val transactionId = createSpend(spendingKey, zatoshi, toAddress, memo)
repository.findEncodedTransactionById(transactionId)
?: throw TransactionEncoderException.TransactionNotFoundException(transactionId)
@ -54,7 +55,7 @@ class WalletTransactionEncoder(
spendingKey: String,
transparentSecretKey: String,
memo: ByteArray?
): EncodedTransaction = withContext(IO) {
): EncodedTransaction = withContext(SdkDispatchers.IO) {
val transactionId = createShieldingSpend(spendingKey, transparentSecretKey, memo)
repository.findEncodedTransactionById(transactionId)
?: throw TransactionEncoderException.TransactionNotFoundException(transactionId)
@ -68,7 +69,7 @@ class WalletTransactionEncoder(
*
* @return true when the given address is a valid z-addr
*/
override suspend fun isValidShieldedAddress(address: String): Boolean = withContext(IO) {
override suspend fun isValidShieldedAddress(address: String): Boolean = withContext(Dispatchers.IO) {
rustBackend.isValidShieldedAddr(address)
}
@ -80,7 +81,7 @@ class WalletTransactionEncoder(
*
* @return true when the given address is a valid t-addr
*/
override suspend fun isValidTransparentAddress(address: String): Boolean = withContext(IO) {
override suspend fun isValidTransparentAddress(address: String): Boolean = withContext(Dispatchers.IO) {
rustBackend.isValidTransparentAddr(address)
}
@ -110,7 +111,7 @@ class WalletTransactionEncoder(
toAddress: String,
memo: ByteArray? = byteArrayOf(),
fromAccountIndex: Int = 0
): Long = withContext(IO) {
): Long = withContext(Dispatchers.IO) {
twigTask(
"creating transaction to spend $zatoshi zatoshi to" +
" ${toAddress.masked()} with memo $memo"
@ -140,7 +141,7 @@ class WalletTransactionEncoder(
spendingKey: String,
transparentSecretKey: String,
memo: ByteArray? = byteArrayOf()
): Long = withContext(IO) {
): Long = withContext(Dispatchers.IO) {
twigTask("creating transaction to shield all UTXOs") {
try {
SaplingParamTool.ensureParams((rustBackend as RustBackend).pathParamsDir)

View File

@ -0,0 +1,44 @@
package cash.z.ecc.android.sdk.jni
import cash.z.ecc.android.sdk.internal.twig
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext
import java.util.concurrent.atomic.AtomicBoolean
/**
* Loads a native library once. This class is thread-safe.
*
* To use this class, create a singleton instance for each given [libraryName].
*
* @param libraryName Name of the library to load.
*/
internal class NativeLibraryLoader(private val libraryName: String) {
private val isLoaded = AtomicBoolean(false)
private val mutex = Mutex()
suspend fun load() {
// Double-checked locking to avoid the Mutex unless necessary, as the hot path is
// for the library to be loaded since this should only run once for the lifetime
// of the application
if (!isLoaded.get()) {
mutex.withLock {
if (!isLoaded.get()) {
loadRustLibrary()
}
}
}
}
private suspend fun loadRustLibrary() {
try {
withContext(Dispatchers.IO) {
twig("Loading native library $libraryName") { System.loadLibrary(libraryName) }
}
isLoaded.set(true)
} catch (e: Throwable) {
twig("Error while loading native library: ${e.message}")
}
}
}

View File

@ -4,10 +4,14 @@ import cash.z.ecc.android.sdk.exception.BirthdayException
import cash.z.ecc.android.sdk.ext.ZcashSdk.OUTPUT_PARAM_FILE_NAME
import cash.z.ecc.android.sdk.ext.ZcashSdk.SPEND_PARAM_FILE_NAME
import cash.z.ecc.android.sdk.internal.twig
import cash.z.ecc.android.sdk.internal.SdkDispatchers
import cash.z.ecc.android.sdk.internal.ext.deleteSuspend
import cash.z.ecc.android.sdk.tool.DerivationTool
import cash.z.ecc.android.sdk.type.UnifiedViewingKey
import cash.z.ecc.android.sdk.type.WalletBalance
import cash.z.ecc.android.sdk.type.ZcashNetwork
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.io.File
/**
@ -17,10 +21,6 @@ import java.io.File
*/
class RustBackend private constructor() : RustBackendWelding {
init {
load()
}
// Paths
lateinit var pathDataDb: String
internal set
@ -35,14 +35,14 @@ class RustBackend private constructor() : RustBackendWelding {
get() = if (field != -1) field else throw BirthdayException.UninitializedBirthdayException
private set
fun clear(clearCacheDb: Boolean = true, clearDataDb: Boolean = true) {
suspend fun clear(clearCacheDb: Boolean = true, clearDataDb: Boolean = true) {
if (clearCacheDb) {
twig("Deleting the cache database!")
File(pathCacheDb).delete()
File(pathCacheDb).deleteSuspend()
}
if (clearDataDb) {
twig("Deleting the data database!")
File(pathDataDb).delete()
File(pathDataDb).deleteSuspend()
}
}
@ -50,19 +50,31 @@ class RustBackend private constructor() : RustBackendWelding {
// Wrapper Functions
//
override fun initDataDb() = initDataDb(pathDataDb, networkId = network.id)
override suspend fun initDataDb() = withContext(SdkDispatchers.IO) {
initDataDb(
pathDataDb,
networkId = network.id
)
}
override fun initAccountsTable(vararg keys: UnifiedViewingKey): Boolean {
override suspend fun initAccountsTable(vararg keys: UnifiedViewingKey): Boolean {
val extfvks = Array(keys.size) { "" }
val extpubs = Array(keys.size) { "" }
keys.forEachIndexed { i, key ->
extfvks[i] = key.extfvk
extpubs[i] = key.extpub
}
return initAccountsTableWithKeys(pathDataDb, extfvks, extpubs, networkId = network.id)
return withContext(SdkDispatchers.IO) {
initAccountsTableWithKeys(
pathDataDb,
extfvks,
extpubs,
networkId = network.id
)
}
}
override fun initAccountsTable(
override suspend fun initAccountsTable(
seed: ByteArray,
numberOfAccounts: Int
): Array<UnifiedViewingKey> {
@ -71,82 +83,131 @@ class RustBackend private constructor() : RustBackendWelding {
}
}
override fun initBlocksTable(
override suspend fun initBlocksTable(
height: Int,
hash: String,
time: Long,
saplingTree: String
): Boolean {
return initBlocksTable(pathDataDb, height, hash, time, saplingTree, networkId = network.id)
return withContext(SdkDispatchers.IO) {
initBlocksTable(
pathDataDb,
height,
hash,
time,
saplingTree,
networkId = network.id
)
}
}
override fun getShieldedAddress(account: Int) = getShieldedAddress(pathDataDb, account, networkId = network.id)
override suspend fun getShieldedAddress(account: Int) = withContext(SdkDispatchers.IO) {
getShieldedAddress(
pathDataDb,
account,
networkId = network.id
)
}
override fun getTransparentAddress(account: Int, index: Int): String {
override suspend fun getTransparentAddress(account: Int, index: Int): String {
throw NotImplementedError("TODO: implement this at the zcash_client_sqlite level. But for now, use DerivationTool, instead to derive addresses from seeds")
}
override fun getBalance(account: Int) = getBalance(pathDataDb, account, networkId = network.id)
override suspend fun getBalance(account: Int) = withContext(SdkDispatchers.IO) {
getBalance(
pathDataDb,
account,
networkId = network.id
)
}
override fun getVerifiedBalance(account: Int) = getVerifiedBalance(pathDataDb, account, networkId = network.id)
override suspend fun getVerifiedBalance(account: Int) = withContext(SdkDispatchers.IO) {
getVerifiedBalance(
pathDataDb,
account,
networkId = network.id
)
}
override fun getReceivedMemoAsUtf8(idNote: Long) =
getReceivedMemoAsUtf8(pathDataDb, idNote, networkId = network.id)
override suspend fun getReceivedMemoAsUtf8(idNote: Long) =
withContext(SdkDispatchers.IO) { getReceivedMemoAsUtf8(pathDataDb, idNote, networkId = network.id) }
override fun getSentMemoAsUtf8(idNote: Long) = getSentMemoAsUtf8(pathDataDb, idNote, networkId = network.id)
override suspend fun getSentMemoAsUtf8(idNote: Long) = withContext(SdkDispatchers.IO) {
getSentMemoAsUtf8(
pathDataDb,
idNote,
networkId = network.id
)
}
override fun validateCombinedChain() = validateCombinedChain(pathCacheDb, pathDataDb, networkId = network.id,)
override suspend fun validateCombinedChain() = withContext(SdkDispatchers.IO) {
validateCombinedChain(
pathCacheDb,
pathDataDb,
networkId = network.id,
)
}
override fun getNearestRewindHeight(height: Int): Int = getNearestRewindHeight(pathDataDb, height, networkId = network.id)
override suspend fun getNearestRewindHeight(height: Int): Int = withContext(SdkDispatchers.IO) {
getNearestRewindHeight(
pathDataDb,
height,
networkId = network.id
)
}
/**
* Deletes data for all blocks above the given height. Boils down to:
*
* DELETE FROM blocks WHERE height > ?
*/
override fun rewindToHeight(height: Int) = rewindToHeight(pathDataDb, height, networkId = network.id)
override suspend fun rewindToHeight(height: Int) =
withContext(SdkDispatchers.IO) { rewindToHeight(pathDataDb, height, networkId = network.id) }
override fun scanBlocks(limit: Int): Boolean {
override suspend fun scanBlocks(limit: Int): Boolean {
return if (limit > 0) {
scanBlockBatch(pathCacheDb, pathDataDb, limit, networkId = network.id)
withContext(SdkDispatchers.IO) {
scanBlockBatch(
pathCacheDb,
pathDataDb,
limit,
networkId = network.id
)
}
} else {
scanBlocks(pathCacheDb, pathDataDb, networkId = network.id)
withContext(SdkDispatchers.IO) {
scanBlocks(
pathCacheDb,
pathDataDb,
networkId = network.id
)
}
}
}
override fun decryptAndStoreTransaction(tx: ByteArray) = decryptAndStoreTransaction(pathDataDb, tx, networkId = network.id)
override suspend fun decryptAndStoreTransaction(tx: ByteArray) = withContext(SdkDispatchers.IO) {
decryptAndStoreTransaction(
pathDataDb,
tx,
networkId = network.id
)
}
override fun createToAddress(
override suspend fun createToAddress(
consensusBranchId: Long,
account: Int,
extsk: String,
to: String,
value: Long,
memo: ByteArray?
): Long = createToAddress(
pathDataDb,
consensusBranchId,
account,
extsk,
to,
value,
memo ?: ByteArray(0),
"$pathParamsDir/$SPEND_PARAM_FILE_NAME",
"$pathParamsDir/$OUTPUT_PARAM_FILE_NAME",
networkId = network.id,
)
override fun shieldToAddress(
extsk: String,
tsk: String,
memo: ByteArray?
): Long {
twig("TMP: shieldToAddress with db path: $pathDataDb, ${memo?.size}")
return shieldToAddress(
): Long = withContext(SdkDispatchers.IO) {
createToAddress(
pathDataDb,
0,
consensusBranchId,
account,
extsk,
tsk,
to,
value,
memo ?: ByteArray(0),
"$pathParamsDir/$SPEND_PARAM_FILE_NAME",
"$pathParamsDir/$OUTPUT_PARAM_FILE_NAME",
@ -154,31 +215,84 @@ class RustBackend private constructor() : RustBackendWelding {
)
}
override fun putUtxo(
override suspend fun shieldToAddress(
extsk: String,
tsk: String,
memo: ByteArray?
): Long {
twig("TMP: shieldToAddress with db path: $pathDataDb, ${memo?.size}")
return withContext(SdkDispatchers.IO) {
shieldToAddress(
pathDataDb,
0,
extsk,
tsk,
memo ?: ByteArray(0),
"$pathParamsDir/$SPEND_PARAM_FILE_NAME",
"$pathParamsDir/$OUTPUT_PARAM_FILE_NAME",
networkId = network.id,
)
}
}
override suspend fun putUtxo(
tAddress: String,
txId: ByteArray,
index: Int,
script: ByteArray,
value: Long,
height: Int
): Boolean = putUtxo(pathDataDb, tAddress, txId, index, script, value, height, networkId = network.id)
): Boolean = withContext(SdkDispatchers.IO) {
putUtxo(
pathDataDb,
tAddress,
txId,
index,
script,
value,
height,
networkId = network.id
)
}
override fun clearUtxos(
override suspend fun clearUtxos(
tAddress: String,
aboveHeight: Int,
): Boolean = clearUtxos(pathDataDb, tAddress, aboveHeight, networkId = network.id)
): Boolean = withContext(SdkDispatchers.IO) {
clearUtxos(
pathDataDb,
tAddress,
aboveHeight,
networkId = network.id
)
}
override fun getDownloadedUtxoBalance(address: String): WalletBalance {
val verified = getVerifiedTransparentBalance(pathDataDb, address, networkId = network.id)
val total = getTotalTransparentBalance(pathDataDb, address, networkId = network.id)
override suspend fun getDownloadedUtxoBalance(address: String): WalletBalance {
val verified = withContext(SdkDispatchers.IO) {
getVerifiedTransparentBalance(
pathDataDb,
address,
networkId = network.id
)
}
val total = withContext(SdkDispatchers.IO) {
getTotalTransparentBalance(
pathDataDb,
address,
networkId = network.id
)
}
return WalletBalance(total, verified)
}
override fun isValidShieldedAddr(addr: String) = isValidShieldedAddress(addr, networkId = network.id)
override fun isValidShieldedAddr(addr: String) =
isValidShieldedAddress(addr, networkId = network.id)
override fun isValidTransparentAddr(addr: String) = isValidTransparentAddress(addr, networkId = network.id)
override fun isValidTransparentAddr(addr: String) =
isValidTransparentAddress(addr, networkId = network.id)
override fun getBranchIdForHeight(height: Int): Long = branchIdForHeight(height, networkId = network.id)
override fun getBranchIdForHeight(height: Int): Long =
branchIdForHeight(height, networkId = network.id)
// /**
// * This is a proof-of-concept for doing Local RPC, where we are effectively using the JNI
@ -203,19 +317,21 @@ class RustBackend private constructor() : RustBackendWelding {
* Exposes all of the librustzcash functions along with helpers for loading the static library.
*/
companion object {
private var loaded = false
internal val rustLibraryLoader = NativeLibraryLoader("zcashwalletsdk")
/**
* Loads the library and initializes path variables. Although it is best to only call this
* function once, it is idempotent.
*/
fun init(
suspend fun init(
cacheDbPath: String,
dataDbPath: String,
paramsPath: String,
zcashNetwork: ZcashNetwork,
birthdayHeight: Int? = null
): RustBackend {
rustLibraryLoader.load()
return RustBackend().apply {
pathCacheDb = cacheDbPath
pathDataDb = dataDbPath
@ -227,16 +343,6 @@ class RustBackend private constructor() : RustBackendWelding {
}
}
fun load() {
// It is safe to call these things twice but not efficient. So we add a loose check and
// ignore the fact that it's not thread-safe.
if (!loaded) {
twig("Loading RustBackend") {
loadRustLibrary()
}
}
}
/**
* Forwards Rust logs to logcat. This is a function that is intended for debug purposes. All
* logs will be tagged with `cash.z.rust.logs`. Typically, a developer would clone
@ -249,33 +355,24 @@ class RustBackend private constructor() : RustBackendWelding {
*/
fun enableRustLogs() = initLogs()
/**
* The first call made to this object in order to load the Rust backend library. All other
* external function calls will fail if the libraries have not been loaded.
*/
private fun loadRustLibrary() {
try {
System.loadLibrary("zcashwalletsdk")
loaded = true
} catch (e: Throwable) {
twig("Error while loading native library: ${e.message}")
}
}
//
// External Functions
//
@JvmStatic private external fun initDataDb(dbDataPath: String, networkId: Int): Boolean
@JvmStatic
private external fun initDataDb(dbDataPath: String, networkId: Int): Boolean
@JvmStatic private external fun initAccountsTableWithKeys(
@JvmStatic
private external fun initAccountsTableWithKeys(
dbDataPath: String,
extfvk: Array<out String>,
extpub: Array<out String>,
networkId: Int,
): Boolean
@JvmStatic private external fun initBlocksTable(
@JvmStatic
private external fun initBlocksTable(
dbDataPath: String,
height: Int,
hash: String,
@ -364,7 +461,8 @@ class RustBackend private constructor() : RustBackendWelding {
networkId: Int,
)
@JvmStatic private external fun createToAddress(
@JvmStatic
private external fun createToAddress(
dbDataPath: String,
consensusBranchId: Long,
account: Int,
@ -377,7 +475,8 @@ class RustBackend private constructor() : RustBackendWelding {
networkId: Int,
): Long
@JvmStatic private external fun shieldToAddress(
@JvmStatic
private external fun shieldToAddress(
dbDataPath: String,
account: Int,
extsk: String,
@ -388,11 +487,14 @@ class RustBackend private constructor() : RustBackendWelding {
networkId: Int,
): Long
@JvmStatic private external fun initLogs()
@JvmStatic
private external fun initLogs()
@JvmStatic private external fun branchIdForHeight(height: Int, networkId: Int): Long
@JvmStatic
private external fun branchIdForHeight(height: Int, networkId: Int): Long
@JvmStatic private external fun putUtxo(
@JvmStatic
private external fun putUtxo(
dbDataPath: String,
tAddress: String,
txId: ByteArray,
@ -403,23 +505,27 @@ class RustBackend private constructor() : RustBackendWelding {
networkId: Int,
): Boolean
@JvmStatic private external fun clearUtxos(
@JvmStatic
private external fun clearUtxos(
dbDataPath: String,
tAddress: String,
aboveHeight: Int,
networkId: Int,
): Boolean
@JvmStatic private external fun getVerifiedTransparentBalance(
@JvmStatic
private external fun getVerifiedTransparentBalance(
pathDataDb: String,
taddr: String,
networkId: Int,
): Long
@JvmStatic private external fun getTotalTransparentBalance(
@JvmStatic
private external fun getTotalTransparentBalance(
pathDataDb: String,
taddr: String,
networkId: Int,
): Long
}
}

View File

@ -14,7 +14,7 @@ interface RustBackendWelding {
val network: ZcashNetwork
fun createToAddress(
suspend fun createToAddress(
consensusBranchId: Long,
account: Int,
extsk: String,
@ -23,51 +23,51 @@ interface RustBackendWelding {
memo: ByteArray? = byteArrayOf()
): Long
fun shieldToAddress(
suspend fun shieldToAddress(
extsk: String,
tsk: String,
memo: ByteArray? = byteArrayOf()
): Long
fun decryptAndStoreTransaction(tx: ByteArray)
suspend fun decryptAndStoreTransaction(tx: ByteArray)
fun initAccountsTable(seed: ByteArray, numberOfAccounts: Int): Array<UnifiedViewingKey>
suspend fun initAccountsTable(seed: ByteArray, numberOfAccounts: Int): Array<UnifiedViewingKey>
fun initAccountsTable(vararg keys: UnifiedViewingKey): Boolean
suspend fun initAccountsTable(vararg keys: UnifiedViewingKey): Boolean
fun initBlocksTable(height: Int, hash: String, time: Long, saplingTree: String): Boolean
suspend fun initBlocksTable(height: Int, hash: String, time: Long, saplingTree: String): Boolean
fun initDataDb(): Boolean
suspend fun initDataDb(): Boolean
fun isValidShieldedAddr(addr: String): Boolean
fun isValidTransparentAddr(addr: String): Boolean
fun getShieldedAddress(account: Int = 0): String
suspend fun getShieldedAddress(account: Int = 0): String
fun getTransparentAddress(account: Int = 0, index: Int = 0): String
suspend fun getTransparentAddress(account: Int = 0, index: Int = 0): String
fun getBalance(account: Int = 0): Long
suspend fun getBalance(account: Int = 0): Long
fun getBranchIdForHeight(height: Int): Long
fun getReceivedMemoAsUtf8(idNote: Long): String
suspend fun getReceivedMemoAsUtf8(idNote: Long): String
fun getSentMemoAsUtf8(idNote: Long): String
suspend fun getSentMemoAsUtf8(idNote: Long): String
fun getVerifiedBalance(account: Int = 0): Long
suspend fun getVerifiedBalance(account: Int = 0): Long
// fun parseTransactionDataList(tdl: LocalRpcTypes.TransactionDataList): LocalRpcTypes.TransparentTransactionList
fun getNearestRewindHeight(height: Int): Int
suspend fun getNearestRewindHeight(height: Int): Int
fun rewindToHeight(height: Int): Boolean
suspend fun rewindToHeight(height: Int): Boolean
fun scanBlocks(limit: Int = -1): Boolean
suspend fun scanBlocks(limit: Int = -1): Boolean
fun validateCombinedChain(): Int
suspend fun validateCombinedChain(): Int
fun putUtxo(
suspend fun putUtxo(
tAddress: String,
txId: ByteArray,
index: Int,
@ -76,59 +76,59 @@ interface RustBackendWelding {
height: Int
): Boolean
fun clearUtxos(tAddress: String, aboveHeight: Int = network.saplingActivationHeight - 1): Boolean
suspend fun clearUtxos(tAddress: String, aboveHeight: Int = network.saplingActivationHeight - 1): Boolean
fun getDownloadedUtxoBalance(address: String): WalletBalance
suspend fun getDownloadedUtxoBalance(address: String): WalletBalance
// Implemented by `DerivationTool`
interface Derivation {
fun deriveShieldedAddress(
suspend fun deriveShieldedAddress(
viewingKey: String,
network: ZcashNetwork
): String
fun deriveShieldedAddress(
suspend fun deriveShieldedAddress(
seed: ByteArray,
network: ZcashNetwork,
accountIndex: Int = 0,
): String
fun deriveSpendingKeys(
suspend fun deriveSpendingKeys(
seed: ByteArray,
network: ZcashNetwork,
numberOfAccounts: Int = 1,
): Array<String>
fun deriveTransparentAddress(
suspend fun deriveTransparentAddress(
seed: ByteArray,
network: ZcashNetwork,
account: Int = 0,
index: Int = 0,
): String
fun deriveTransparentAddressFromPublicKey(
suspend fun deriveTransparentAddressFromPublicKey(
publicKey: String,
network: ZcashNetwork
): String
fun deriveTransparentAddressFromPrivateKey(
suspend fun deriveTransparentAddressFromPrivateKey(
privateKey: String,
network: ZcashNetwork
): String
fun deriveTransparentSecretKey(
suspend fun deriveTransparentSecretKey(
seed: ByteArray,
network: ZcashNetwork,
account: Int = 0,
index: Int = 0,
): String
fun deriveViewingKey(
suspend fun deriveViewingKey(
spendingKey: String,
network: ZcashNetwork
): String
fun deriveUnifiedViewingKeys(
suspend fun deriveUnifiedViewingKeys(
seed: ByteArray,
network: ZcashNetwork,
numberOfAccounts: Int = 1,

View File

@ -18,7 +18,7 @@ class DerivationTool {
*
* @return the viewing keys that correspond to the seed, formatted as Strings.
*/
override fun deriveUnifiedViewingKeys(seed: ByteArray, network: ZcashNetwork, numberOfAccounts: Int): Array<UnifiedViewingKey> =
override suspend fun deriveUnifiedViewingKeys(seed: ByteArray, network: ZcashNetwork, numberOfAccounts: Int): Array<UnifiedViewingKey> =
withRustBackendLoaded {
deriveUnifiedViewingKeysFromSeed(seed, numberOfAccounts, networkId = network.id).map {
UnifiedViewingKey(it[0], it[1])
@ -32,7 +32,7 @@ class DerivationTool {
*
* @return the viewing key that corresponds to the spending key.
*/
override fun deriveViewingKey(spendingKey: String, network: ZcashNetwork): String = withRustBackendLoaded {
override suspend fun deriveViewingKey(spendingKey: String, network: ZcashNetwork): String = withRustBackendLoaded {
deriveExtendedFullViewingKey(spendingKey, networkId = network.id)
}
@ -45,7 +45,7 @@ class DerivationTool {
*
* @return the spending keys that correspond to the seed, formatted as Strings.
*/
override fun deriveSpendingKeys(seed: ByteArray, network: ZcashNetwork, numberOfAccounts: Int): Array<String> =
override suspend fun deriveSpendingKeys(seed: ByteArray, network: ZcashNetwork, numberOfAccounts: Int): Array<String> =
withRustBackendLoaded {
deriveExtendedSpendingKeys(seed, numberOfAccounts, networkId = network.id)
}
@ -59,7 +59,7 @@ class DerivationTool {
*
* @return the address that corresponds to the seed and account index.
*/
override fun deriveShieldedAddress(seed: ByteArray, network: ZcashNetwork, accountIndex: Int): String =
override suspend fun deriveShieldedAddress(seed: ByteArray, network: ZcashNetwork, accountIndex: Int): String =
withRustBackendLoaded {
deriveShieldedAddressFromSeed(seed, accountIndex, networkId = network.id)
}
@ -72,26 +72,26 @@ class DerivationTool {
*
* @return the address that corresponds to the viewing key.
*/
override fun deriveShieldedAddress(viewingKey: String, network: ZcashNetwork): String = withRustBackendLoaded {
override suspend fun deriveShieldedAddress(viewingKey: String, network: ZcashNetwork): String = withRustBackendLoaded {
deriveShieldedAddressFromViewingKey(viewingKey, networkId = network.id)
}
// WIP probably shouldn't be used just yet. Why?
// - because we need the private key associated with this seed and this function doesn't return it.
// - the underlying implementation needs to be split out into a few lower-level calls
override fun deriveTransparentAddress(seed: ByteArray, network: ZcashNetwork, account: Int, index: Int): String = withRustBackendLoaded {
override suspend fun deriveTransparentAddress(seed: ByteArray, network: ZcashNetwork, account: Int, index: Int): String = withRustBackendLoaded {
deriveTransparentAddressFromSeed(seed, account, index, networkId = network.id)
}
override fun deriveTransparentAddressFromPublicKey(transparentPublicKey: String, network: ZcashNetwork): String = withRustBackendLoaded {
override suspend fun deriveTransparentAddressFromPublicKey(transparentPublicKey: String, network: ZcashNetwork): String = withRustBackendLoaded {
deriveTransparentAddressFromPubKey(transparentPublicKey, networkId = network.id)
}
override fun deriveTransparentAddressFromPrivateKey(transparentPrivateKey: String, network: ZcashNetwork): String = withRustBackendLoaded {
override suspend fun deriveTransparentAddressFromPrivateKey(transparentPrivateKey: String, network: ZcashNetwork): String = withRustBackendLoaded {
deriveTransparentAddressFromPrivKey(transparentPrivateKey, networkId = network.id)
}
override fun deriveTransparentSecretKey(seed: ByteArray, network: ZcashNetwork, account: Int, index: Int): String = withRustBackendLoaded {
override suspend fun deriveTransparentSecretKey(seed: ByteArray, network: ZcashNetwork, account: Int, index: Int): String = withRustBackendLoaded {
deriveTransparentSecretKeyFromSeed(seed, account, index, networkId = network.id)
}
@ -104,8 +104,8 @@ class DerivationTool {
* class attempts to interact with it, indirectly, by invoking JNI functions. It would be
* nice to have an annotation like @UsesSystemLibrary for this
*/
private fun <T> withRustBackendLoaded(block: () -> T): T {
RustBackend.load()
private suspend fun <T> withRustBackendLoaded(block: () -> T): T {
RustBackend.rustLibraryLoader.load()
return block()
}

View File

@ -8,165 +8,156 @@ import cash.z.ecc.android.sdk.type.WalletBirthday
import cash.z.ecc.android.sdk.type.ZcashNetwork
import com.google.gson.Gson
import com.google.gson.stream.JsonReader
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.io.IOException
import java.io.InputStreamReader
import java.util.Locale
/**
* Tool for loading checkpoints for the wallet, based on the height at which the wallet was born.
*
* @param appContext needed for loading checkpoints from the app's assets directory.
*/
class WalletBirthdayTool(appContext: Context) {
private val context = appContext.applicationContext
object WalletBirthdayTool {
// Behavior change implemented as a fix for issue #270. Temporarily adding a boolean
// that allows the change to be rolled back quickly if needed, although long-term
// this flag should be removed.
@VisibleForTesting
internal val IS_FALLBACK_ON_FAILURE = true
/**
* Load the nearest checkpoint to the given birthday height. If null is given, then this
* will load the most recent checkpoint available.
*/
fun loadNearest(network: ZcashNetwork, birthdayHeight: Int? = null): WalletBirthday {
suspend fun loadNearest(
context: Context,
network: ZcashNetwork,
birthdayHeight: Int? = null
): WalletBirthday {
// TODO: potentially pull from shared preferences first
return loadBirthdayFromAssets(context, network, birthdayHeight)
}
companion object {
// Behavior change implemented as a fix for issue #270. Temporarily adding a boolean
// that allows the change to be rolled back quickly if needed, although long-term
// this flag should be removed.
@VisibleForTesting
internal val IS_FALLBACK_ON_FAILURE = true
/**
* Load the nearest checkpoint to the given birthday height. If null is given, then this
* will load the most recent checkpoint available.
*/
fun loadNearest(
context: Context,
network: ZcashNetwork,
birthdayHeight: Int? = null
): WalletBirthday {
// TODO: potentially pull from shared preferences first
return loadBirthdayFromAssets(context, network, birthdayHeight)
}
/**
* Useful for when an exact checkpoint is needed, like for SAPLING_ACTIVATION_HEIGHT. In
* most cases, loading the nearest checkpoint is preferred for privacy reasons.
*/
fun loadExact(context: Context, network: ZcashNetwork, birthdayHeight: Int) =
loadNearest(context, network, birthdayHeight).also {
if (it.height != birthdayHeight)
throw BirthdayException.ExactBirthdayNotFoundException(
birthdayHeight,
it.height
)
}
// TODO: This method performs disk IO; convert to suspending function
// Converting this to suspending will then propagate
@Throws(IOException::class)
internal fun listBirthdayDirectoryContents(context: Context, directory: String) =
context.assets.list(directory)
/**
* Returns the directory within the assets folder where birthday data
* (i.e. sapling trees for a given height) can be found.
*/
@VisibleForTesting
internal fun birthdayDirectory(network: ZcashNetwork) =
"saplingtree/${(network.networkName as java.lang.String).toLowerCase(Locale.US)}"
internal fun birthdayHeight(fileName: String) = fileName.split('.').first().toInt()
private fun Array<String>.sortDescending() =
apply { sortByDescending { birthdayHeight(it) } }
/**
* Load the given birthday file from the assets of the given context. When no height is
* specified, we default to the file with the greatest name.
*
* @param context the context from which to load assets.
* @param birthdayHeight the height file to look for among the file names.
*
* @return a WalletBirthday that reflects the contents of the file or an exception when
* parsing fails.
*/
private fun loadBirthdayFromAssets(
context: Context,
network: ZcashNetwork,
birthdayHeight: Int? = null
): WalletBirthday {
twig("loading birthday from assets: $birthdayHeight")
val directory = birthdayDirectory(network)
val treeFiles = getFilteredFileNames(context, directory, birthdayHeight)
twig("found ${treeFiles.size} sapling tree checkpoints: $treeFiles")
return getFirstValidWalletBirthday(context, directory, treeFiles)
}
private fun getFilteredFileNames(
context: Context,
directory: String,
birthdayHeight: Int? = null,
): List<String> {
val unfilteredTreeFiles = listBirthdayDirectoryContents(context, directory)
if (unfilteredTreeFiles.isNullOrEmpty()) {
throw BirthdayException.MissingBirthdayFilesException(directory)
}
val filteredTreeFiles = unfilteredTreeFiles
.sortDescending()
.filter { filename ->
birthdayHeight?.let { birthdayHeight(filename) <= it } ?: true
}
if (filteredTreeFiles.isEmpty()) {
throw BirthdayException.BirthdayFileNotFoundException(
directory,
birthdayHeight
/**
* Useful for when an exact checkpoint is needed, like for SAPLING_ACTIVATION_HEIGHT. In
* most cases, loading the nearest checkpoint is preferred for privacy reasons.
*/
suspend fun loadExact(context: Context, network: ZcashNetwork, birthdayHeight: Int) =
loadNearest(context, network, birthdayHeight).also {
if (it.height != birthdayHeight)
throw BirthdayException.ExactBirthdayNotFoundException(
birthdayHeight,
it.height
)
}
return filteredTreeFiles
}
/**
* @param treeFiles A list of files, sorted in descending order based on `int` value of the first part of the filename.
*/
@VisibleForTesting
internal fun getFirstValidWalletBirthday(
context: Context,
directory: String,
treeFiles: List<String>
): WalletBirthday {
var lastException: Exception? = null
treeFiles.forEach { treefile ->
try {
// Converting this to suspending will then propagate
@Throws(IOException::class)
internal suspend fun listBirthdayDirectoryContents(context: Context, directory: String) =
withContext(Dispatchers.IO) {
context.assets.list(directory)
}
/**
* Returns the directory within the assets folder where birthday data
* (i.e. sapling trees for a given height) can be found.
*/
@VisibleForTesting
internal fun birthdayDirectory(network: ZcashNetwork) =
"saplingtree/${(network.networkName as java.lang.String).toLowerCase(Locale.US)}"
internal fun birthdayHeight(fileName: String) = fileName.split('.').first().toInt()
private fun Array<String>.sortDescending() =
apply { sortByDescending { birthdayHeight(it) } }
/**
* Load the given birthday file from the assets of the given context. When no height is
* specified, we default to the file with the greatest name.
*
* @param context the context from which to load assets.
* @param birthdayHeight the height file to look for among the file names.
*
* @return a WalletBirthday that reflects the contents of the file or an exception when
* parsing fails.
*/
private suspend fun loadBirthdayFromAssets(
context: Context,
network: ZcashNetwork,
birthdayHeight: Int? = null
): WalletBirthday {
twig("loading birthday from assets: $birthdayHeight")
val directory = birthdayDirectory(network)
val treeFiles = getFilteredFileNames(context, directory, birthdayHeight)
twig("found ${treeFiles.size} sapling tree checkpoints: $treeFiles")
return getFirstValidWalletBirthday(context, directory, treeFiles)
}
private suspend fun getFilteredFileNames(
context: Context,
directory: String,
birthdayHeight: Int? = null,
): List<String> {
val unfilteredTreeFiles = listBirthdayDirectoryContents(context, directory)
if (unfilteredTreeFiles.isNullOrEmpty()) {
throw BirthdayException.MissingBirthdayFilesException(directory)
}
val filteredTreeFiles = unfilteredTreeFiles
.sortDescending()
.filter { filename ->
birthdayHeight?.let { birthdayHeight(filename) <= it } ?: true
}
if (filteredTreeFiles.isEmpty()) {
throw BirthdayException.BirthdayFileNotFoundException(
directory,
birthdayHeight
)
}
return filteredTreeFiles
}
/**
* @param treeFiles A list of files, sorted in descending order based on `int` value of the first part of the filename.
*/
@VisibleForTesting
internal suspend fun getFirstValidWalletBirthday(
context: Context,
directory: String,
treeFiles: List<String>
): WalletBirthday {
var lastException: Exception? = null
treeFiles.forEach { treefile ->
try {
return withContext(Dispatchers.IO) {
context.assets.open("$directory/$treefile").use { inputStream ->
InputStreamReader(inputStream).use { inputStreamReader ->
JsonReader(inputStreamReader).use { jsonReader ->
return Gson().fromJson(jsonReader, WalletBirthday::class.java)
Gson().fromJson(jsonReader, WalletBirthday::class.java)
}
}
}
} catch (t: Throwable) {
val exception = BirthdayException.MalformattedBirthdayFilesException(
directory,
treefile
)
lastException = exception
}
} catch (t: Throwable) {
val exception = BirthdayException.MalformattedBirthdayFilesException(
directory,
treefile
)
lastException = exception
if (IS_FALLBACK_ON_FAILURE) {
// TODO: If we ever add crash analytics hooks, this would be something to report
twig("Malformed birthday file $t")
} else {
throw exception
}
if (IS_FALLBACK_ON_FAILURE) {
// TODO: If we ever add crash analytics hooks, this would be something to report
twig("Malformed birthday file $t")
} else {
throw exception
}
}
throw lastException!!
}
throw lastException!!
}
}