From 6bc858a8889fcc982027a4bf70e00039768bb6af Mon Sep 17 00:00:00 2001 From: Kristofer Peterson Date: Sat, 6 Mar 2021 22:29:12 +0000 Subject: [PATCH] Refactored ShortU16Visitor::visit_seq() to reject overflows, extra leading zeros and ensure one-to-one encoding. --- perf/src/sigverify.rs | 9 +-- sdk/program/src/short_vec.rs | 134 +++++++++++++++++++++++------------ 2 files changed, 95 insertions(+), 48 deletions(-) diff --git a/perf/src/sigverify.rs b/perf/src/sigverify.rs index 7473a0e1d..356d217cb 100644 --- a/perf/src/sigverify.rs +++ b/perf/src/sigverify.rs @@ -13,7 +13,7 @@ use solana_metrics::inc_new_counter_debug; use solana_rayon_threadlimit::get_thread_count; use solana_sdk::message::MESSAGE_HEADER_LENGTH; use solana_sdk::pubkey::Pubkey; -use solana_sdk::short_vec::decode_len; +use solana_sdk::short_vec::decode_shortu16_len; use solana_sdk::signature::Signature; #[cfg(test)] use solana_sdk::transaction::Transaction; @@ -163,7 +163,7 @@ fn do_get_packet_offsets( // read the length of Transaction.signatures (serialized with short_vec) let (sig_len_untrusted, sig_size) = - decode_len(&packet.data).map_err(|_| PacketError::InvalidShortVec)?; + decode_shortu16_len(&packet.data).map_err(|_| PacketError::InvalidShortVec)?; // Using msg_start_offset which is based on sig_len_untrusted introduces uncertainty. // Ultimately, the actual sigverify will determine the uncertainty. 
@@ -203,8 +203,9 @@ fn do_get_packet_offsets( } // read the length of Message.account_keys (serialized with short_vec) - let (pubkey_len, pubkey_len_size) = decode_len(&packet.data[message_account_keys_len_offset..]) - .map_err(|_| PacketError::InvalidShortVec)?; + let (pubkey_len, pubkey_len_size) = + decode_shortu16_len(&packet.data[message_account_keys_len_offset..]) + .map_err(|_| PacketError::InvalidShortVec)?; let pubkey_start = message_account_keys_len_offset .checked_add(pubkey_len_size) diff --git a/sdk/program/src/short_vec.rs b/sdk/program/src/short_vec.rs index 2c09b8de3..67409144b 100644 --- a/sdk/program/src/short_vec.rs +++ b/sdk/program/src/short_vec.rs @@ -4,7 +4,7 @@ use serde::{ ser::{self, SerializeTuple, Serializer}, {Deserialize, Serialize}, }; -use std::{fmt, marker::PhantomData, mem::size_of}; +use std::{convert::TryFrom, fmt, marker::PhantomData}; /// Same as u16, but serialized with 1 to 3 bytes. If the value is above /// 0x7f, the top bit is set and the remaining value is stored in the next @@ -39,23 +39,77 @@ impl Serialize for ShortU16 { } } -enum VisitResult { - Done(usize, usize), - More(usize, usize), - Err, +enum VisitStatus { + Done(u16), + More(u16), } -fn visit_byte(elem: u8, val: usize, size: usize) -> VisitResult { - let val = val | (elem as usize & 0x7f) << (size * 7); - let size = size + 1; - let more = elem as usize & 0x80 == 0x80; +#[derive(Debug)] +enum VisitError { + TooLong(usize), + TooShort(usize), + Overflow(u32), + Alias, + ByteThreeContinues, +} - if size > size_of::<u16>() + 1 { - VisitResult::Err - } else if more { - VisitResult::More(val, size) +impl VisitError { + fn into_de_error<'de, A>(self) -> A::Error + where + A: SeqAccess<'de>, + { + match self { + VisitError::TooLong(len) => { + de::Error::invalid_length(len as usize, &"three or fewer bytes") + } + VisitError::TooShort(len) => de::Error::invalid_length(len, &"more bytes"), + VisitError::Overflow(val) => de::Error::invalid_value( + de::Unexpected::Unsigned(val 
as u64), + &"a value in the range [0, 65535]", + ), + VisitError::Alias => de::Error::invalid_value( + de::Unexpected::Other("alias encoding"), + &"strict form encoding", + ), + VisitError::ByteThreeContinues => de::Error::invalid_value( + de::Unexpected::Other("continue signal on byte-three"), + &"a terminal signal on or before byte-three", + ), + } + } +} + +type VisitResult = Result<VisitStatus, VisitError>; + +const MAX_ENCODING_LENGTH: usize = 3; +fn visit_byte(elem: u8, val: u16, nth_byte: usize) -> VisitResult { + if elem == 0 && nth_byte != 0 { + return Err(VisitError::Alias); + } + + let val = u32::from(val); + let elem = u32::from(elem); + let elem_val = elem & 0x7f; + let elem_done = (elem & 0x80) == 0; + + if nth_byte >= MAX_ENCODING_LENGTH { + return Err(VisitError::TooLong(nth_byte.saturating_add(1))); + } else if nth_byte == MAX_ENCODING_LENGTH.saturating_sub(1) && !elem_done { + return Err(VisitError::ByteThreeContinues); + } + + let shift = u32::try_from(nth_byte) + .unwrap_or(u32::MAX) + .saturating_mul(7); + let elem_val = elem_val.checked_shl(shift).unwrap_or(u32::MAX); + + let new_val = val | elem_val; + let val = u16::try_from(new_val).map_err(|_| VisitError::Overflow(new_val))?; + + if elem_done { + Ok(VisitStatus::Done(val)) } else { - VisitResult::Done(val, size) + Ok(VisitStatus::More(val)) } } @@ -72,27 +126,22 @@ impl<'de> Visitor<'de> for ShortU16Visitor { where A: SeqAccess<'de>, { - let mut val: usize = 0; - let mut size: usize = 0; - loop { - let elem: u8 = seq - .next_element()? 
- .ok_or_else(|| de::Error::invalid_length(size, &self))?; - - match visit_byte(elem, val, size) { - VisitResult::Done(l, _) => { - val = l; - break; - } - VisitResult::More(l, s) => { - val = l; - size = s; - } - VisitResult::Err => return Err(de::Error::invalid_length(size + 1, &self)), + // Decodes an unsigned 16 bit integer one-to-one encoded as follows: + // 1 byte : 0xxxxxxx => 00000000 0xxxxxxx : 0 - 127 + // 2 bytes : 1xxxxxxx 0yyyyyyy => 00yyyyyy yxxxxxxx : 128 - 16,383 + // 3 bytes : 1xxxxxxx 1yyyyyyy 000000zz => zzyyyyyy yxxxxxxx : 16,384 - 65,535 + let mut val: u16 = 0; + for nth_byte in 0..MAX_ENCODING_LENGTH { + let elem: u8 = seq.next_element()?.ok_or_else(|| { + VisitError::TooShort(nth_byte.saturating_add(1)).into_de_error::<A>() + })?; + match visit_byte(elem, val, nth_byte).map_err(|e| e.into_de_error::<A>())? { + VisitStatus::Done(new_val) => return Ok(ShortU16(new_val)), + VisitStatus::More(new_val) => val = new_val, } } - Ok(ShortU16(val as u16)) + Err(VisitError::ByteThreeContinues.into_de_error::<A>()) } } @@ -201,17 +250,14 @@ impl<'de, T: Deserialize<'de>> Deserialize<'de> for ShortVec<T> { /// Return the decoded value and how many bytes it consumed. #[allow(clippy::result_unit_err)] -pub fn decode_len(bytes: &[u8]) -> Result<(usize, usize), ()> { - let mut len = 0; - let mut size = 0; - for byte in bytes.iter() { - match visit_byte(*byte, len, size) { - VisitResult::More(l, s) => { - len = l; - size = s; +pub fn decode_shortu16_len(bytes: &[u8]) -> Result<(usize, usize), ()> { + let mut val = 0; + for (nth_byte, byte) in bytes.iter().take(MAX_ENCODING_LENGTH).enumerate() { + match visit_byte(*byte, val, nth_byte).map_err(|_| ())? 
{ + VisitStatus::More(new_val) => val = new_val, + VisitStatus::Done(new_val) => { + return Ok((usize::from(new_val), nth_byte.saturating_add(1))); } - VisitResult::Done(len, size) => return Ok((len, size)), - VisitResult::Err => return Err(()), } } Err(()) @@ -231,8 +277,8 @@ mod tests { fn assert_len_encoding(len: u16, bytes: &[u8]) { assert_eq!(encode_len(len), bytes, "unexpected usize encoding"); assert_eq!( - decode_len(bytes).unwrap(), - (len as usize, bytes.len()), + decode_shortu16_len(bytes).unwrap(), + (usize::from(len), bytes.len()), "unexpected usize decoding" ); }