diff --git a/rust/src/lib.rs b/rust/src/lib.rs index b6b590b9..c6a8f7a0 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -42,8 +42,6 @@ use zcash_primitives::consensus::{TestNetwork, TEST_NETWORK}; use zcash_proofs::prover::LocalTxProver; -use zcash_primitives::consensus::{TestNetwork, TEST_NETWORK}; - use std::convert::TryFrom; // ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -76,38 +74,32 @@ where } } +#[cfg(feature = "mainnet")] +pub const NETWORK: MainNetwork = MAIN_NETWORK; + +#[cfg(not(feature = "mainnet"))] +pub const NETWORK: TestNetwork = TEST_NETWORK; + + fn wallet_db(db_data: *const u8, db_data_len: usize) -> Result { - let res = catch_panic(|| { - Path::new(OsStr::from_bytes(unsafe { + let db_data = Path::new(OsStr::from_bytes(unsafe { slice::from_raw_parts(db_data, db_data_len) - })) - - }); - - match res { - Ok(value) => WalletDB::for_path(value) - .map_err(|e| format_err!("Error opening wallet database connection: {}", e)), - Err(e) => e - } + })); + WalletDB::for_path(value) + .map_err(|e| format_err!("Error opening wallet database connection: {}", e)) } fn block_db(cache_db: *const u8, cache_db_len: usize) -> Result { - - let res = catch_panic(|| { - Path::new(OsStr::from_bytes(unsafe { - slice::from_raw_parts(db_data, db_data_len) - })) - - }); - - match res { - Ok(value) => BlockDB::for_path(value) - .map_err(|e| format_err!("Error opening block source database connection: {}", e)), - Err(e) => e - } + + let cache_db = Path::new(OsStr::from_bytes(unsafe { + slice::from_raw_parts(cache_db, cache_db_len) + })); + BlockDB::for_path(cache_db) + .map_err(|e| format_err!("Error opening block source database connection: {}", e)) + } /// Returns the length of the last error message to be logged. @@ -168,7 +160,7 @@ pub extern "C" fn zcashlc_init_accounts_table( }; let extsks: Vec<_> = (0..accounts) - .map(|account| spending_key(&seed, COIN_TYPE, account)) + .map(|account| spending_key(&seed, NETWORK.coin_type(), account)) .collect(); let extfvks: Vec<_> = extsks.iter().map(ExtendedFullViewingKey::from).collect(); @@ -205,28 +197,20 @@ pub extern "C" fn zcashlc_init_accounts_table_with_keys( extfvks_len: usize, ) -> bool { let res = catch_panic(|| { - let db_data = Path::new(OsStr::from_bytes(unsafe { - slice::from_raw_parts(db_data, db_data_len) - })); + let db_data = wallet_db(db_data, db_data_len)?; - let extfvks = unsafe { std::slice::from_raw_parts(extfvks, extfvks_len) - .into_iter() - .map(|s| CStr::from_ptr(*s).to_str().unwrap()) - .map( |vkstr| - decode_extended_full_viewing_key(HRP_SAPLING_EXTENDED_FULL_VIEWING_KEY, &vkstr) - .unwrap() - .unwrap() - ).collect::>() }; + let extfvks = unsafe { std::slice::from_raw_parts(extfvks, extfvks_len) + .into_iter() + .map(|s| CStr::from_ptr(*s).to_str().unwrap()) + .map( |vkstr| + decode_extended_full_viewing_key(NETWORK.hrp_sapling_extended_full_viewing_key(), &vkstr) + .unwrap() + .unwrap() + ).collect::>() }; - match init_accounts_table(&db_data,&extfvks) { + match init_accounts_table(&db_data, &NETWORK, &extfvks) { Ok(()) => Ok(true), - Err(e) => match e.kind() { - ErrorKind::TableNotEmpty => { - // Ignore this error. - Ok(true) - } - _ => return Err(format_err!("Error while initializing accounts: {}", e)), - }, + Err(e) => Err(format_err!("Error while initializing accounts: {}", e)), } });