Limit short_vec length to u16, usize is overkill for our usage (#4588)

This commit is contained in:
Michael Vines 2019-06-06 20:18:41 -07:00 committed by GitHub
parent fd9fd43e83
commit 492cc93850
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 28 deletions

1
Cargo.lock generated
View File

@ -2649,6 +2649,7 @@ dependencies = [
name = "solana-sdk" name = "solana-sdk"
version = "0.16.0" version = "0.16.0"
dependencies = [ dependencies = [
"assert_matches 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"bincode 1.1.4 (registry+https://github.com/rust-lang/crates.io-index)", "bincode 1.1.4 (registry+https://github.com/rust-lang/crates.io-index)",
"bs58 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)", "bs58 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
"byteorder 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)", "byteorder 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)",

View File

@ -9,6 +9,7 @@ license = "Apache-2.0"
edition = "2018" edition = "2018"
[dependencies] [dependencies]
assert_matches = "1.3.0"
bincode = "1.1.4" bincode = "1.1.4"
bs58 = "0.2.0" bs58 = "0.2.0"
byteorder = "1.2.1" byteorder = "1.2.1"

View File

@ -1,18 +1,18 @@
use serde::de::{self, Deserializer, SeqAccess, Visitor}; use serde::de::{self, Deserializer, SeqAccess, Visitor};
use serde::ser::{SerializeTuple, Serializer}; use serde::ser::{self, SerializeTuple, Serializer};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt; use std::fmt;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::mem::size_of; use std::mem::size_of;
/// Same as usize, but serialized with 1 to 9 bytes. If the value is above /// Same as u16, but serialized with 1 to 3 bytes. If the value is above
/// 0x7f, the top bit is set and the remaining value is stored in the next /// 0x7f, the top bit is set and the remaining value is stored in the next
/// bytes. Each byte follows the same pattern until the 9th byte. The 9th /// bytes. Each byte follows the same pattern until the 3rd byte. The 3rd
/// byte, if needed, uses all 8 bits to store the last byte of the original /// byte, if needed, uses all 8 bits to store the last byte of the original
/// value. /// value.
pub struct ShortUsize(pub usize); pub struct ShortU16(pub u16);
impl Serialize for ShortUsize { impl Serialize for ShortU16 {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where where
S: Serializer, S: Serializer,
@ -40,13 +40,13 @@ impl Serialize for ShortUsize {
struct ShortLenVisitor; struct ShortLenVisitor;
impl<'de> Visitor<'de> for ShortLenVisitor { impl<'de> Visitor<'de> for ShortLenVisitor {
type Value = ShortUsize; type Value = ShortU16;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a multi-byte length") formatter.write_str("a multi-byte length")
} }
fn visit_seq<A>(self, mut seq: A) -> Result<ShortUsize, A::Error> fn visit_seq<A>(self, mut seq: A) -> Result<ShortU16, A::Error>
where where
A: SeqAccess<'de>, A: SeqAccess<'de>,
{ {
@ -64,21 +64,21 @@ impl<'de> Visitor<'de> for ShortLenVisitor {
break; break;
} }
if size > size_of::<usize>() + 1 { if size > size_of::<u16>() + 1 {
return Err(de::Error::invalid_length(size, &self)); return Err(de::Error::invalid_length(size, &self));
} }
} }
Ok(ShortUsize(len)) Ok(ShortU16(len as u16))
} }
} }
impl<'de> Deserialize<'de> for ShortUsize { impl<'de> Deserialize<'de> for ShortU16 {
fn deserialize<D>(deserializer: D) -> Result<ShortUsize, D::Error> fn deserialize<D>(deserializer: D) -> Result<ShortU16, D::Error>
where where
D: Deserializer<'de>, D: Deserializer<'de>,
{ {
deserializer.deserialize_tuple(9, ShortLenVisitor) deserializer.deserialize_tuple(3, ShortLenVisitor)
} }
} }
@ -95,7 +95,11 @@ pub fn serialize<S: Serializer, T: Serialize>(
// generate an open bracket. // generate an open bracket.
let mut seq = serializer.serialize_tuple(1)?; let mut seq = serializer.serialize_tuple(1)?;
let short_len = ShortUsize(elements.len()); let len = elements.len();
if len > std::u16::MAX as usize {
return Err(ser::Error::custom("length larger than u16"));
}
let short_len = ShortU16(len as u16);
seq.serialize_element(&short_len)?; seq.serialize_element(&short_len)?;
for element in elements { for element in elements {
@ -122,10 +126,10 @@ where
where where
A: SeqAccess<'de>, A: SeqAccess<'de>,
{ {
let short_len: ShortUsize = seq let short_len: ShortU16 = seq
.next_element()? .next_element()?
.ok_or_else(|| de::Error::invalid_length(0, &self))?; .ok_or_else(|| de::Error::invalid_length(0, &self))?;
let len = short_len.0; let len = short_len.0 as usize;
let mut result = Vec::with_capacity(len); let mut result = Vec::with_capacity(len);
for i in 0..len { for i in 0..len {
@ -172,28 +176,29 @@ impl<'de, T: Deserialize<'de>> Deserialize<'de> for ShortVec<T> {
} }
} }
/// Return the serialized length.
pub fn encode_len(len: usize) -> Vec<u8> {
bincode::serialize(&ShortUsize(len)).unwrap()
}
/// Return the decoded value and how many bytes it consumed. /// Return the decoded value and how many bytes it consumed.
pub fn decode_len(bytes: &[u8]) -> (usize, usize) { pub fn decode_len(bytes: &[u8]) -> (usize, usize) {
let short_len: ShortUsize = bincode::deserialize(bytes).unwrap(); let short_len: ShortU16 = bincode::deserialize(bytes).unwrap();
let num_bytes = bincode::serialized_size(&short_len).unwrap() as usize; let num_bytes = bincode::serialized_size(&short_len).unwrap() as usize;
(short_len.0, num_bytes) (short_len.0 as usize, num_bytes)
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use assert_matches::assert_matches;
use bincode::{deserialize, serialize}; use bincode::{deserialize, serialize};
fn assert_len_encoding(len: usize, bytes: &[u8]) { /// Return the serialized length.
fn encode_len(len: u16) -> Vec<u8> {
bincode::serialize(&ShortU16(len)).unwrap()
}
fn assert_len_encoding(len: u16, bytes: &[u8]) {
assert_eq!(encode_len(len), bytes, "unexpected usize encoding"); assert_eq!(encode_len(len), bytes, "unexpected usize encoding");
assert_eq!( assert_eq!(
decode_len(bytes), decode_len(bytes),
(len, bytes.len()), (len as usize, bytes.len()),
"unexpected usize decoding" "unexpected usize decoding"
); );
} }
@ -206,8 +211,7 @@ mod tests {
assert_len_encoding(0xff, &[0xff, 0x01]); assert_len_encoding(0xff, &[0xff, 0x01]);
assert_len_encoding(0x100, &[0x80, 0x02]); assert_len_encoding(0x100, &[0x80, 0x02]);
assert_len_encoding(0x7fff, &[0xff, 0xff, 0x01]); assert_len_encoding(0x7fff, &[0xff, 0xff, 0x01]);
assert_len_encoding(0x200000, &[0x80, 0x80, 0x80, 0x01]); assert_len_encoding(0xffff, &[0xff, 0xff, 0x03]);
assert_len_encoding(0x7ffffffff, &[0xff, 0xff, 0xff, 0xff, 0x7f]);
} }
#[test] #[test]
@ -226,10 +230,19 @@ mod tests {
assert_eq!(vec.0, vec1.0); assert_eq!(vec.0, vec1.0);
} }
#[test]
fn test_short_vec_u8_too_long() {
let vec = ShortVec(vec![4u8; std::u16::MAX as usize]);
assert_matches!(serialize(&vec), Ok(_));
let vec = ShortVec(vec![4u8; std::u16::MAX as usize + 1]);
assert_matches!(serialize(&vec), Err(_));
}
#[test] #[test]
fn test_short_vec_json() { fn test_short_vec_json() {
let vec = ShortVec(vec![0u8]); let vec = ShortVec(vec![0, 1, 2]);
let s = serde_json::to_string(&vec).unwrap(); let s = serde_json::to_string(&vec).unwrap();
assert_eq!(s, "[[1],0]"); assert_eq!(s, "[[3],0,1,2]");
} }
} }