diff --git a/Cargo.lock b/Cargo.lock index ae32cd463..731c0a832 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2649,6 +2649,7 @@ dependencies = [ name = "solana-sdk" version = "0.16.0" dependencies = [ + "assert_matches 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "bincode 1.1.4 (registry+https://github.com/rust-lang/crates.io-index)", "bs58 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)", "byteorder 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/sdk/Cargo.toml b/sdk/Cargo.toml index e0dc886f7..33bbe5587 100644 --- a/sdk/Cargo.toml +++ b/sdk/Cargo.toml @@ -9,6 +9,7 @@ license = "Apache-2.0" edition = "2018" [dependencies] +assert_matches = "1.3.0" bincode = "1.1.4" bs58 = "0.2.0" byteorder = "1.2.1" diff --git a/sdk/src/short_vec.rs b/sdk/src/short_vec.rs index 220be393f..3d8103f16 100644 --- a/sdk/src/short_vec.rs +++ b/sdk/src/short_vec.rs @@ -1,18 +1,18 @@ use serde::de::{self, Deserializer, SeqAccess, Visitor}; -use serde::ser::{SerializeTuple, Serializer}; +use serde::ser::{self, SerializeTuple, Serializer}; use serde::{Deserialize, Serialize}; use std::fmt; use std::marker::PhantomData; use std::mem::size_of; -/// Same as usize, but serialized with 1 to 9 bytes. If the value is above +/// Same as u16, but serialized with 1 to 3 bytes. If the value is above /// 0x7f, the top bit is set and the remaining value is stored in the next -/// bytes. Each byte follows the same pattern until the 9th byte. The 9th +/// bytes. Each byte follows the same pattern until the 3rd byte. The 3rd /// byte, if needed, uses all 8 bits to store the last byte of the original /// value. -pub struct ShortUsize(pub usize); +pub struct ShortU16(pub u16); -impl Serialize for ShortUsize { +impl Serialize for ShortU16 { fn serialize(&self, serializer: S) -> Result where S: Serializer, @@ -40,13 +40,13 @@ impl Serialize for ShortUsize { struct ShortLenVisitor; impl<'de> Visitor<'de> for ShortLenVisitor { - type Value = ShortUsize; + type Value = ShortU16; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("a multi-byte length") } - fn visit_seq(self, mut seq: A) -> Result + fn visit_seq(self, mut seq: A) -> Result where A: SeqAccess<'de>, { @@ -64,21 +64,21 @@ impl<'de> Visitor<'de> for ShortLenVisitor { break; } - if size > size_of::() + 1 { + if size > size_of::() + 1 { return Err(de::Error::invalid_length(size, &self)); } } - Ok(ShortUsize(len)) + Ok(ShortU16(len as u16)) } } -impl<'de> Deserialize<'de> for ShortUsize { - fn deserialize(deserializer: D) -> Result +impl<'de> Deserialize<'de> for ShortU16 { + fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { - deserializer.deserialize_tuple(9, ShortLenVisitor) + deserializer.deserialize_tuple(3, ShortLenVisitor) } } @@ -95,7 +95,11 @@ pub fn serialize( // generate an open bracket. let mut seq = serializer.serialize_tuple(1)?; - let short_len = ShortUsize(elements.len()); + let len = elements.len(); + if len > std::u16::MAX as usize { + return Err(ser::Error::custom("length larger than u16")); + } + let short_len = ShortU16(len as u16); seq.serialize_element(&short_len)?; for element in elements { @@ -122,10 +126,10 @@ where where A: SeqAccess<'de>, { - let short_len: ShortUsize = seq + let short_len: ShortU16 = seq .next_element()? .ok_or_else(|| de::Error::invalid_length(0, &self))?; - let len = short_len.0; + let len = short_len.0 as usize; let mut result = Vec::with_capacity(len); for i in 0..len { @@ -172,28 +176,29 @@ impl<'de, T: Deserialize<'de>> Deserialize<'de> for ShortVec { } } -/// Return the serialized length. -pub fn encode_len(len: usize) -> Vec { - bincode::serialize(&ShortUsize(len)).unwrap() -} - /// Return the decoded value and how many bytes it consumed. pub fn decode_len(bytes: &[u8]) -> (usize, usize) { - let short_len: ShortUsize = bincode::deserialize(bytes).unwrap(); + let short_len: ShortU16 = bincode::deserialize(bytes).unwrap(); let num_bytes = bincode::serialized_size(&short_len).unwrap() as usize; - (short_len.0, num_bytes) + (short_len.0 as usize, num_bytes) } #[cfg(test)] mod tests { use super::*; + use assert_matches::assert_matches; use bincode::{deserialize, serialize}; - fn assert_len_encoding(len: usize, bytes: &[u8]) { + /// Return the serialized length. + fn encode_len(len: u16) -> Vec { + bincode::serialize(&ShortU16(len)).unwrap() + } + + fn assert_len_encoding(len: u16, bytes: &[u8]) { assert_eq!(encode_len(len), bytes, "unexpected usize encoding"); assert_eq!( decode_len(bytes), - (len, bytes.len()), + (len as usize, bytes.len()), "unexpected usize decoding" ); } @@ -206,8 +211,7 @@ mod tests { assert_len_encoding(0xff, &[0xff, 0x01]); assert_len_encoding(0x100, &[0x80, 0x02]); assert_len_encoding(0x7fff, &[0xff, 0xff, 0x01]); - assert_len_encoding(0x200000, &[0x80, 0x80, 0x80, 0x01]); - assert_len_encoding(0x7ffffffff, &[0xff, 0xff, 0xff, 0xff, 0x7f]); + assert_len_encoding(0xffff, &[0xff, 0xff, 0x03]); } #[test] @@ -226,10 +230,19 @@ mod tests { assert_eq!(vec.0, vec1.0); } + #[test] + fn test_short_vec_u8_too_long() { + let vec = ShortVec(vec![4u8; std::u16::MAX as usize]); + assert_matches!(serialize(&vec), Ok(_)); + + let vec = ShortVec(vec![4u8; std::u16::MAX as usize + 1]); + assert_matches!(serialize(&vec), Err(_)); + } + #[test] fn test_short_vec_json() { - let vec = ShortVec(vec![0u8]); + let vec = ShortVec(vec![0, 1, 2]); let s = serde_json::to_string(&vec).unwrap(); - assert_eq!(s, "[[1],0]"); + assert_eq!(s, "[[3],0,1,2]"); } }