Refactored ShortU16Visitor::visit_seq() to reject overflows, extra leading zeros and ensure one-to-one encoding.
This commit is contained in:
parent
9153cb9237
commit
6bc858a888
|
@ -13,7 +13,7 @@ use solana_metrics::inc_new_counter_debug;
|
|||
use solana_rayon_threadlimit::get_thread_count;
|
||||
use solana_sdk::message::MESSAGE_HEADER_LENGTH;
|
||||
use solana_sdk::pubkey::Pubkey;
|
||||
use solana_sdk::short_vec::decode_len;
|
||||
use solana_sdk::short_vec::decode_shortu16_len;
|
||||
use solana_sdk::signature::Signature;
|
||||
#[cfg(test)]
|
||||
use solana_sdk::transaction::Transaction;
|
||||
|
@ -163,7 +163,7 @@ fn do_get_packet_offsets(
|
|||
|
||||
// read the length of Transaction.signatures (serialized with short_vec)
|
||||
let (sig_len_untrusted, sig_size) =
|
||||
decode_len(&packet.data).map_err(|_| PacketError::InvalidShortVec)?;
|
||||
decode_shortu16_len(&packet.data).map_err(|_| PacketError::InvalidShortVec)?;
|
||||
|
||||
// Using msg_start_offset which is based on sig_len_untrusted introduces uncertainty.
|
||||
// Ultimately, the actual sigverify will determine the uncertainty.
|
||||
|
@ -203,7 +203,8 @@ fn do_get_packet_offsets(
|
|||
}
|
||||
|
||||
// read the length of Message.account_keys (serialized with short_vec)
|
||||
let (pubkey_len, pubkey_len_size) = decode_len(&packet.data[message_account_keys_len_offset..])
|
||||
let (pubkey_len, pubkey_len_size) =
|
||||
decode_shortu16_len(&packet.data[message_account_keys_len_offset..])
|
||||
.map_err(|_| PacketError::InvalidShortVec)?;
|
||||
|
||||
let pubkey_start = message_account_keys_len_offset
|
||||
|
|
|
@ -4,7 +4,7 @@ use serde::{
|
|||
ser::{self, SerializeTuple, Serializer},
|
||||
{Deserialize, Serialize},
|
||||
};
|
||||
use std::{fmt, marker::PhantomData, mem::size_of};
|
||||
use std::{convert::TryFrom, fmt, marker::PhantomData};
|
||||
|
||||
/// Same as u16, but serialized with 1 to 3 bytes. If the value is above
|
||||
/// 0x7f, the top bit is set and the remaining value is stored in the next
|
||||
|
@ -39,23 +39,77 @@ impl Serialize for ShortU16 {
|
|||
}
|
||||
}
|
||||
|
||||
enum VisitResult {
|
||||
Done(usize, usize),
|
||||
More(usize, usize),
|
||||
Err,
|
||||
enum VisitStatus {
|
||||
Done(u16),
|
||||
More(u16),
|
||||
}
|
||||
|
||||
fn visit_byte(elem: u8, val: usize, size: usize) -> VisitResult {
|
||||
let val = val | (elem as usize & 0x7f) << (size * 7);
|
||||
let size = size + 1;
|
||||
let more = elem as usize & 0x80 == 0x80;
|
||||
#[derive(Debug)]
|
||||
enum VisitError {
|
||||
TooLong(usize),
|
||||
TooShort(usize),
|
||||
Overflow(u32),
|
||||
Alias,
|
||||
ByteThreeContinues,
|
||||
}
|
||||
|
||||
if size > size_of::<u16>() + 1 {
|
||||
VisitResult::Err
|
||||
} else if more {
|
||||
VisitResult::More(val, size)
|
||||
impl VisitError {
|
||||
fn into_de_error<'de, A>(self) -> A::Error
|
||||
where
|
||||
A: SeqAccess<'de>,
|
||||
{
|
||||
match self {
|
||||
VisitError::TooLong(len) => {
|
||||
de::Error::invalid_length(len as usize, &"three or fewer bytes")
|
||||
}
|
||||
VisitError::TooShort(len) => de::Error::invalid_length(len, &"more bytes"),
|
||||
VisitError::Overflow(val) => de::Error::invalid_value(
|
||||
de::Unexpected::Unsigned(val as u64),
|
||||
&"a value in the range [0, 65535]",
|
||||
),
|
||||
VisitError::Alias => de::Error::invalid_value(
|
||||
de::Unexpected::Other("alias encoding"),
|
||||
&"strict form encoding",
|
||||
),
|
||||
VisitError::ByteThreeContinues => de::Error::invalid_value(
|
||||
de::Unexpected::Other("continue signal on byte-three"),
|
||||
&"a terminal signal on or before byte-three",
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type VisitResult = Result<VisitStatus, VisitError>;
|
||||
|
||||
const MAX_ENCODING_LENGTH: usize = 3;
|
||||
fn visit_byte(elem: u8, val: u16, nth_byte: usize) -> VisitResult {
|
||||
if elem == 0 && nth_byte != 0 {
|
||||
return Err(VisitError::Alias);
|
||||
}
|
||||
|
||||
let val = u32::from(val);
|
||||
let elem = u32::from(elem);
|
||||
let elem_val = elem & 0x7f;
|
||||
let elem_done = (elem & 0x80) == 0;
|
||||
|
||||
if nth_byte >= MAX_ENCODING_LENGTH {
|
||||
return Err(VisitError::TooLong(nth_byte.saturating_add(1)));
|
||||
} else if nth_byte == MAX_ENCODING_LENGTH.saturating_sub(1) && !elem_done {
|
||||
return Err(VisitError::ByteThreeContinues);
|
||||
}
|
||||
|
||||
let shift = u32::try_from(nth_byte)
|
||||
.unwrap_or(u32::MAX)
|
||||
.saturating_mul(7);
|
||||
let elem_val = elem_val.checked_shl(shift).unwrap_or(u32::MAX);
|
||||
|
||||
let new_val = val | elem_val;
|
||||
let val = u16::try_from(new_val).map_err(|_| VisitError::Overflow(new_val))?;
|
||||
|
||||
if elem_done {
|
||||
Ok(VisitStatus::Done(val))
|
||||
} else {
|
||||
VisitResult::Done(val, size)
|
||||
Ok(VisitStatus::More(val))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -72,27 +126,22 @@ impl<'de> Visitor<'de> for ShortU16Visitor {
|
|||
where
|
||||
A: SeqAccess<'de>,
|
||||
{
|
||||
let mut val: usize = 0;
|
||||
let mut size: usize = 0;
|
||||
loop {
|
||||
let elem: u8 = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| de::Error::invalid_length(size, &self))?;
|
||||
|
||||
match visit_byte(elem, val, size) {
|
||||
VisitResult::Done(l, _) => {
|
||||
val = l;
|
||||
break;
|
||||
}
|
||||
VisitResult::More(l, s) => {
|
||||
val = l;
|
||||
size = s;
|
||||
}
|
||||
VisitResult::Err => return Err(de::Error::invalid_length(size + 1, &self)),
|
||||
// Decodes an unsigned 16 bit integer one-to-one encoded as follows:
|
||||
// 1 byte : 0xxxxxxx => 00000000 0xxxxxxx : 0 - 127
|
||||
// 2 bytes : 1xxxxxxx 0yyyyyyy => 00yyyyyy yxxxxxxx : 128 - 16,383
|
||||
// 3 bytes : 1xxxxxxx 1yyyyyyy 000000zz => zzyyyyyy yxxxxxxx : 16,384 - 65,535
|
||||
let mut val: u16 = 0;
|
||||
for nth_byte in 0..MAX_ENCODING_LENGTH {
|
||||
let elem: u8 = seq.next_element()?.ok_or_else(|| {
|
||||
VisitError::TooShort(nth_byte.saturating_add(1)).into_de_error::<A>()
|
||||
})?;
|
||||
match visit_byte(elem, val, nth_byte).map_err(|e| e.into_de_error::<A>())? {
|
||||
VisitStatus::Done(new_val) => return Ok(ShortU16(new_val)),
|
||||
VisitStatus::More(new_val) => val = new_val,
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ShortU16(val as u16))
|
||||
Err(VisitError::ByteThreeContinues.into_de_error::<A>())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -201,17 +250,14 @@ impl<'de, T: Deserialize<'de>> Deserialize<'de> for ShortVec<T> {
|
|||
|
||||
/// Return the decoded value and how many bytes it consumed.
|
||||
#[allow(clippy::result_unit_err)]
|
||||
pub fn decode_len(bytes: &[u8]) -> Result<(usize, usize), ()> {
|
||||
let mut len = 0;
|
||||
let mut size = 0;
|
||||
for byte in bytes.iter() {
|
||||
match visit_byte(*byte, len, size) {
|
||||
VisitResult::More(l, s) => {
|
||||
len = l;
|
||||
size = s;
|
||||
pub fn decode_shortu16_len(bytes: &[u8]) -> Result<(usize, usize), ()> {
|
||||
let mut val = 0;
|
||||
for (nth_byte, byte) in bytes.iter().take(MAX_ENCODING_LENGTH).enumerate() {
|
||||
match visit_byte(*byte, val, nth_byte).map_err(|_| ())? {
|
||||
VisitStatus::More(new_val) => val = new_val,
|
||||
VisitStatus::Done(new_val) => {
|
||||
return Ok((usize::from(new_val), nth_byte.saturating_add(1)));
|
||||
}
|
||||
VisitResult::Done(len, size) => return Ok((len, size)),
|
||||
VisitResult::Err => return Err(()),
|
||||
}
|
||||
}
|
||||
Err(())
|
||||
|
@ -231,8 +277,8 @@ mod tests {
|
|||
fn assert_len_encoding(len: u16, bytes: &[u8]) {
|
||||
assert_eq!(encode_len(len), bytes, "unexpected usize encoding");
|
||||
assert_eq!(
|
||||
decode_len(bytes).unwrap(),
|
||||
(len as usize, bytes.len()),
|
||||
decode_shortu16_len(bytes).unwrap(),
|
||||
(usize::from(len), bytes.len()),
|
||||
"unexpected usize decoding"
|
||||
);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue