diff --git a/sdk/program/src/borsh.rs b/sdk/program/src/borsh.rs index e85d685040..0137de6f2a 100644 --- a/sdk/program/src/borsh.rs +++ b/sdk/program/src/borsh.rs @@ -2,9 +2,9 @@ //! Borsh utils use { borsh::{ - maybestd::io::Error, + maybestd::io::{Error, Write}, schema::{BorshSchema, Declaration, Definition, Fields}, - BorshDeserialize, + BorshDeserialize, BorshSerialize, }, std::collections::HashMap, }; @@ -56,6 +56,9 @@ fn get_declaration_packed_len( } /// Get the worst-case packed length for the given BorshSchema +/// +/// Note: due to the serializer currently used by Borsh, this function cannot +/// be used on-chain in the Solana BPF execution environment. pub fn get_packed_len() -> usize { let schema_container = S::schema_container(); get_declaration_packed_len(&schema_container.declaration, &schema_container.definitions) @@ -76,15 +79,45 @@ pub fn try_from_slice_unchecked(data: &[u8]) -> Result Result { + let amount = data.len(); + self.count += amount; + Ok(amount) + } + + fn flush(&mut self) -> Result<(), Error> { + Ok(()) + } +} + +/// Get the packed length for the serialized form of this object instance. +/// +/// Useful when working with instances of types that contain a variable-length +/// sequence, such as a Vec or HashMap. Since it is impossible to know the packed +/// length only from the type's schema, this can be used when an instance already +/// exists, to figure out how much space to allocate in an account. +pub fn get_instance_packed_len(instance: &T) -> Result { + let mut counter = WriteCounter::default(); + instance.serialize(&mut counter)?; + Ok(counter.count) +} + #[cfg(test)] mod tests { use { super::*, borsh::{maybestd::io::ErrorKind, BorshSchema, BorshSerialize}, - std::mem::size_of, + std::{collections::HashMap, mem::size_of}, }; - #[derive(BorshSerialize, BorshDeserialize, BorshSchema)] + #[derive(PartialEq, Clone, Debug, BorshSerialize, BorshDeserialize, BorshSchema)] enum TestEnum { NoValue, Value(u32), @@ -96,7 +129,14 @@ mod tests { }, } - #[derive(BorshSerialize, BorshDeserialize, BorshSchema)] + // for test simplicity + impl Default for TestEnum { + fn default() -> Self { + Self::NoValue + } + } + + #[derive(Default, BorshSerialize, BorshDeserialize, BorshSchema)] struct TestStruct { pub array: [u64; 16], pub number_u128: u128, @@ -159,4 +199,101 @@ mod tests { + get_packed_len::() ); } + + #[test] + fn instance_packed_len_matches_packed_len() { + let enumeration = TestEnum::StructValue { + number: u64::MAX, + array: [255; 8], + }; + assert_eq!( + get_packed_len::(), + get_instance_packed_len(&enumeration).unwrap(), + ); + let test_struct = TestStruct { + enumeration, + ..TestStruct::default() + }; + assert_eq!( + get_packed_len::(), + get_instance_packed_len(&test_struct).unwrap(), + ); + assert_eq!( + get_packed_len::(), + get_instance_packed_len(&0u8).unwrap(), + ); + assert_eq!( + get_packed_len::(), + get_instance_packed_len(&0u16).unwrap(), + ); + assert_eq!( + get_packed_len::(), + get_instance_packed_len(&0u32).unwrap(), + ); + assert_eq!( + get_packed_len::(), + get_instance_packed_len(&0u64).unwrap(), + ); + assert_eq!( + get_packed_len::(), + get_instance_packed_len(&0u128).unwrap(), + ); + assert_eq!( + get_packed_len::<[u8; 10]>(), + get_instance_packed_len(&[0u8; 10]).unwrap(), + ); + assert_eq!( + get_packed_len::<(i8, i16, i32, i64, i128)>(), + get_instance_packed_len(&(i8::MAX, i16::MAX, i32::MAX, i64::MAX, i128::MAX)).unwrap(), + ); + } + + #[test] + fn instance_packed_len_with_vec() { + let data = vec![ + Child { data: [0u8; 64] }, + Child { data: [1u8; 64] }, + Child { data: [2u8; 64] }, + Child { data: [3u8; 64] }, + Child { data: [4u8; 64] }, + Child { data: [5u8; 64] }, + ]; + let parent = Parent { data }; + assert_eq!( + get_instance_packed_len(&parent).unwrap(), + 4 + parent.data.len() * get_packed_len::() + ); + } + + #[derive(Debug, PartialEq, BorshSerialize, BorshDeserialize, BorshSchema)] + struct StructWithHashMap { + data: HashMap, + } + + #[test] + fn instance_packed_len_with_varying_sizes_in_hashmap() { + let mut data = HashMap::new(); + let string1 = "the first string, it's actually really really long".to_string(); + let enum1 = TestEnum::NoValue; + let string2 = "second string, shorter".to_string(); + let enum2 = TestEnum::Value(u32::MAX); + let string3 = "third".to_string(); + let enum3 = TestEnum::StructValue { + number: 0, + array: [0; 8], + }; + data.insert(string1.clone(), enum1.clone()); + data.insert(string2.clone(), enum2.clone()); + data.insert(string3.clone(), enum3.clone()); + let instance = StructWithHashMap { data }; + assert_eq!( + get_instance_packed_len(&instance).unwrap(), + 4 + get_instance_packed_len(&string1).unwrap() + + get_instance_packed_len(&enum1).unwrap() + + get_instance_packed_len(&string2).unwrap() + + get_instance_packed_len(&enum2).unwrap() + + get_instance_packed_len(&string3).unwrap() + + get_instance_packed_len(&enum3).unwrap() + ); + } }