rust: Handle passing Rust Vecs to Swift where len != capacity

Fixes zcash/ZcashLightClientKit#177.
This commit is contained in:
Jack Grigg 2020-10-08 18:43:41 +01:00 committed by Francisco Gindre
parent 0ebb61dd0f
commit 344699a6b2
3 changed files with 28 additions and 17 deletions

View File

@ -76,7 +76,8 @@ class ZcashRustBackend: ZcashRustBackendWelding {
static func initAccountsTable(dbData: URL, seed: [UInt8], accounts: Int32) -> [String]? {
let dbData = dbData.osStr()
let extsksCStr = zcashlc_init_accounts_table(dbData.0, dbData.1, seed, UInt(seed.count), accounts)
let capacity = UInt(0);
let extsksCStr = zcashlc_init_accounts_table(dbData.0, dbData.1, seed, UInt(seed.count), accounts, &capacity)
if extsksCStr == nil {
return nil
}
@ -85,7 +86,7 @@ class ZcashRustBackend: ZcashRustBackendWelding {
guard let str = cStr else { return nil }
return String(cString: str)
})
zcashlc_vec_string_free(extsksCStr, UInt(accounts))
zcashlc_vec_string_free(extsksCStr, UInt(accounts), capacity)
return extsks
}
@ -208,8 +209,8 @@ class ZcashRustBackend: ZcashRustBackendWelding {
}
static func deriveExtendedFullViewingKeys(seed: String, accounts: Int32) throws -> [String]? {
guard let extsksCStr = zcashlc_derive_extended_full_viewing_keys(seed, UInt(seed.lengthOfBytes(using: .utf8)), accounts) else {
let capacity = UInt(0);
guard let extsksCStr = zcashlc_derive_extended_full_viewing_keys(seed, UInt(seed.lengthOfBytes(using: .utf8)), accounts, capacity) else {
if let error = lastError() {
throw error
}
@ -220,12 +221,13 @@ class ZcashRustBackend: ZcashRustBackendWelding {
guard let str = cStr else { return nil }
return String(cString: str)
})
zcashlc_vec_string_free(extsksCStr, UInt(accounts))
zcashlc_vec_string_free(extsksCStr, UInt(accounts), capacity)
return extsks
}
static func deriveExtendedSpendingKeys(seed: String, accounts: Int32) throws -> [String]? {
guard let extsksCStr = zcashlc_derive_extended_spending_keys(seed, UInt(seed.lengthOfBytes(using: .utf8)), accounts) else {
let capacity = UInt(0);
guard let extsksCStr = zcashlc_derive_extended_spending_keys(seed, UInt(seed.lengthOfBytes(using: .utf8)), accounts, &capacity) else {
if let error = lastError() {
throw error
}
@ -236,7 +238,7 @@ class ZcashRustBackend: ZcashRustBackendWelding {
guard let str = cStr else { return nil }
return String(cString: str)
})
zcashlc_vec_string_free(extsksCStr, UInt(accounts))
zcashlc_vec_string_free(extsksCStr, UInt(accounts), capacity)
return extsks
}

View File

@ -42,11 +42,13 @@ char *zcashlc_derive_extended_full_viewing_key(const char *extsk);
char **zcashlc_derive_extended_full_viewing_keys(const uint8_t *seed,
uintptr_t seed_len,
int32_t accounts);
int32_t accounts,
uintptr_t *capacity_ret);
char **zcashlc_derive_extended_spending_keys(const uint8_t *seed,
uintptr_t seed_len,
int32_t accounts);
int32_t accounts,
uintptr_t *capacity_ret);
/**
* Copies the last error message into the provided allocated buffer.
@ -107,7 +109,8 @@ char **zcashlc_init_accounts_table(const uint8_t *db_data,
uintptr_t db_data_len,
const uint8_t *seed,
uintptr_t seed_len,
int32_t accounts);
int32_t accounts,
uintptr_t *capacity_ret);
/**
* Initialises the data database with the given block.
@ -205,4 +208,4 @@ int32_t zcashlc_validate_combined_chain(const uint8_t *db_cache,
/**
* Frees vectors of strings returned by other zcashlc functions.
*/
void zcashlc_vec_string_free(char **v, uintptr_t len);
void zcashlc_vec_string_free(char **v, uintptr_t len, uintptr_t capacity);

View File

@ -115,6 +115,7 @@ pub extern "C" fn zcashlc_init_accounts_table(
seed: *const u8,
seed_len: usize,
accounts: i32,
capacity_ret: *mut usize,
) -> *mut *mut c_char {
let res = catch_panic(|| {
let db_data = Path::new(OsStr::from_bytes(unsafe {
@ -151,7 +152,8 @@ pub extern "C" fn zcashlc_init_accounts_table(
CString::new(encoded).unwrap().into_raw()
})
.collect();
assert!(v.len() == v.capacity());
assert!(v.len() == accounts as usize);
unsafe { *capacity_ret.as_mut().unwrap() = v.capacity() };
let p = v.as_mut_ptr();
std::mem::forget(v);
Ok(p)
@ -197,6 +199,7 @@ pub unsafe extern "C" fn zcashlc_derive_extended_spending_keys(
seed: *const u8,
seed_len: usize,
accounts: i32,
capacity_ret: *mut usize,
) -> *mut *mut c_char {
let res = catch_panic(|| {
let seed = slice::from_raw_parts(seed, seed_len);
@ -219,7 +222,8 @@ pub unsafe extern "C" fn zcashlc_derive_extended_spending_keys(
CString::new(encoded).unwrap().into_raw()
})
.collect();
assert!(v.len() == v.capacity());
assert!(v.len() == accounts as usize);
*capacity_ret.as_mut().unwrap() = v.capacity();
let p = v.as_mut_ptr();
std::mem::forget(v);
Ok(p)
@ -232,6 +236,7 @@ pub unsafe extern "C" fn zcashlc_derive_extended_full_viewing_keys(
seed: *const u8,
seed_len: usize,
accounts: i32,
capacity_ret: *mut usize,
) -> *mut *mut c_char {
let res = catch_panic(|| {
let seed = slice::from_raw_parts(seed, seed_len);
@ -254,7 +259,8 @@ pub unsafe extern "C" fn zcashlc_derive_extended_full_viewing_keys(
CString::new(encoded).unwrap().into_raw()
})
.collect();
assert!(v.len() == v.capacity());
assert!(v.len() == accounts as usize);
*capacity_ret.as_mut().unwrap() = v.capacity();
let p = v.as_mut_ptr();
std::mem::forget(v);
Ok(p)
@ -703,13 +709,13 @@ pub extern "C" fn zcashlc_string_free(s: *mut c_char) {
/// Frees vectors of strings returned by other zcashlc functions.
#[no_mangle]
pub extern "C" fn zcashlc_vec_string_free(v: *mut *mut c_char, len: usize) {
pub extern "C" fn zcashlc_vec_string_free(v: *mut *mut c_char, len: usize, capacity: usize) {
unsafe {
if v.is_null() {
return;
}
// All Vecs created by other functions MUST have length == capacity.
let v = Vec::from_raw_parts(v, len, len);
assert!(len <= capacity);
let v = Vec::from_raw_parts(v, len, capacity);
v.into_iter().map(|s| CString::from_raw(s)).for_each(drop);
};
}