From c7a57672d7fbc2ab75844b01b60b7f63b052e3a1 Mon Sep 17 00:00:00 2001 From: Brooks Date: Thu, 14 Dec 2023 13:25:25 -0500 Subject: [PATCH] Adds read/write/get_pod() fns to tiered storage (#34415) --- accounts-db/src/tiered_storage/byte_block.rs | 70 ++++++++++++++++---- accounts-db/src/tiered_storage/file.rs | 56 ++++++++++++++-- accounts-db/src/tiered_storage/footer.rs | 32 +++++---- accounts-db/src/tiered_storage/hot.rs | 21 +++--- accounts-db/src/tiered_storage/index.rs | 10 +-- accounts-db/src/tiered_storage/mmap_utils.rs | 26 +++++++- accounts-db/src/tiered_storage/owners.rs | 6 +- 7 files changed, 170 insertions(+), 51 deletions(-) diff --git a/accounts-db/src/tiered_storage/byte_block.rs b/accounts-db/src/tiered_storage/byte_block.rs index 6ddd66c28..28e32815b 100644 --- a/accounts-db/src/tiered_storage/byte_block.rs +++ b/accounts-db/src/tiered_storage/byte_block.rs @@ -53,11 +53,31 @@ impl ByteBlockWriter { self.len } + /// Write plain ol' data to the internal buffer of the ByteBlockWriter instance + /// + /// Prefer this over `write_type()`, as it prevents some undefined behavior. + pub fn write_pod(&mut self, value: &T) -> IoResult { + // SAFETY: Since T is NoUninit, it does not contain any uninitialized bytes. + unsafe { self.write_type(value) } + } + /// Write the specified typed instance to the internal buffer of /// the ByteBlockWriter instance. - pub fn write_type(&mut self, value: &T) -> IoResult { + /// + /// Prefer `write_pod()` when possible, because `write_type()` may cause + /// undefined behavior if `value` contains uninitialized bytes. + /// + /// # Safety + /// + /// Caller must ensure casting T to bytes is safe. + /// Refer to the Safety sections in std::slice::from_raw_parts() + /// and bytemuck's Pod and NoUninit for more information. + pub unsafe fn write_type(&mut self, value: &T) -> IoResult { let size = mem::size_of::(); let ptr = value as *const _ as *const u8; + // SAFETY: The caller ensures that `value` contains no uninitialized bytes, + // we ensure the size is safe by querying T directly, + // and Rust ensures all values are at least byte-aligned. let slice = unsafe { std::slice::from_raw_parts(ptr, size) }; self.write(slice)?; Ok(size) @@ -73,10 +93,10 @@ impl ByteBlockWriter { ) -> IoResult { let mut size = 0; if let Some(rent_epoch) = opt_fields.rent_epoch { - size += self.write_type(&rent_epoch)?; + size += self.write_pod(&rent_epoch)?; } if let Some(hash) = opt_fields.account_hash { - size += self.write_type(&hash)?; + size += self.write_pod(&hash)?; } debug_assert_eq!(size, opt_fields.size()); @@ -112,18 +132,40 @@ impl ByteBlockWriter { /// The util struct for reading byte blocks. pub struct ByteBlockReader; +/// Reads the raw part of the input byte_block, at the specified offset, as type T. +/// +/// Returns None if `offset` + size_of::() exceeds the size of the input byte_block. +/// +/// Type T must be plain ol' data to ensure no undefined behavior. +pub fn read_pod(byte_block: &[u8], offset: usize) -> Option<&T> { + // SAFETY: Since T is AnyBitPattern, it is safe to cast bytes to T. + unsafe { read_type(byte_block, offset) } +} + /// Reads the raw part of the input byte_block at the specified offset /// as type T. /// /// If `offset` + size_of::() exceeds the size of the input byte_block, /// then None will be returned. -pub fn read_type(byte_block: &[u8], offset: usize) -> Option<&T> { +/// +/// Prefer `read_pod()` when possible, because `read_type()` may cause +/// undefined behavior. +/// +/// # Safety +/// +/// Caller must ensure casting bytes to T is safe. +/// Refer to the Safety sections in std::slice::from_raw_parts() +/// and bytemuck's Pod and AnyBitPattern for more information. +pub unsafe fn read_type(byte_block: &[u8], offset: usize) -> Option<&T> { let (next, overflow) = offset.overflowing_add(std::mem::size_of::()); if overflow || next > byte_block.len() { return None; } let ptr = byte_block[offset..].as_ptr() as *const T; debug_assert!(ptr as usize % std::mem::align_of::() == 0); + // SAFETY: The caller ensures it is safe to cast bytes to T, + // we ensure the size is safe by querying T directly, + // and we just checked above to ensure the ptr is aligned for T. Some(unsafe { &*ptr }) } @@ -169,7 +211,7 @@ mod tests { let mut writer = ByteBlockWriter::new(format); let value: u32 = 42; - writer.write_type(&value).unwrap(); + writer.write_pod(&value).unwrap(); assert_eq!(writer.raw_len(), mem::size_of::()); let buffer = writer.finish().unwrap(); @@ -231,12 +273,14 @@ mod tests { let test_data3 = [33u8; 300]; // Write the above meta and data in an interleaving way. - writer.write_type(&test_metas[0]).unwrap(); - writer.write_type(&test_data1).unwrap(); - writer.write_type(&test_metas[1]).unwrap(); - writer.write_type(&test_data2).unwrap(); - writer.write_type(&test_metas[2]).unwrap(); - writer.write_type(&test_data3).unwrap(); + unsafe { + writer.write_type(&test_metas[0]).unwrap(); + writer.write_type(&test_data1).unwrap(); + writer.write_type(&test_metas[1]).unwrap(); + writer.write_type(&test_data2).unwrap(); + writer.write_type(&test_metas[2]).unwrap(); + writer.write_type(&test_data3).unwrap(); + } assert_eq!( writer.raw_len(), mem::size_of::() * 3 @@ -346,13 +390,13 @@ mod tests { let mut offset = 0; for opt_fields in &opt_fields_vec { if let Some(expected_rent_epoch) = opt_fields.rent_epoch { - let rent_epoch = read_type::(&decoded_buffer, offset).unwrap(); + let rent_epoch = read_pod::(&decoded_buffer, offset).unwrap(); assert_eq!(*rent_epoch, expected_rent_epoch); verified_count += 1; offset += std::mem::size_of::(); } if let Some(expected_hash) = opt_fields.account_hash { - let hash = read_type::(&decoded_buffer, offset).unwrap(); + let hash = read_pod::(&decoded_buffer, offset).unwrap(); assert_eq!(hash, &expected_hash); verified_count += 1; offset += std::mem::size_of::(); diff --git a/accounts-db/src/tiered_storage/file.rs b/accounts-db/src/tiered_storage/file.rs index 07faa3e81..51801c613 100644 --- a/accounts-db/src/tiered_storage/file.rs +++ b/accounts-db/src/tiered_storage/file.rs @@ -1,8 +1,11 @@ -use std::{ - fs::{File, OpenOptions}, - io::{Read, Result as IoResult, Seek, SeekFrom, Write}, - mem, - path::Path, +use { + bytemuck::{AnyBitPattern, NoUninit}, + std::{ + fs::{File, OpenOptions}, + io::{Read, Result as IoResult, Seek, SeekFrom, Write}, + mem, + path::Path, + }, }; #[derive(Debug)] @@ -33,14 +36,53 @@ impl TieredStorageFile { )) } - pub fn write_type(&self, value: &T) -> IoResult { + /// Writes `value` to the file. + /// + /// `value` must be plain ol' data. + pub fn write_pod(&self, value: &T) -> IoResult { + // SAFETY: Since T is NoUninit, it does not contain any uninitialized bytes. + unsafe { self.write_type(value) } + } + + /// Writes `value` to the file. + /// + /// Prefer `write_pod` when possible, because `write_value` may cause + /// undefined behavior if `value` contains uninitialized bytes. + /// + /// # Safety + /// + /// Caller must ensure casting T to bytes is safe. + /// Refer to the Safety sections in std::slice::from_raw_parts() + /// and bytemuck's Pod and NoUninit for more information. + pub unsafe fn write_type(&self, value: &T) -> IoResult { let ptr = value as *const _ as *const u8; let bytes = unsafe { std::slice::from_raw_parts(ptr, mem::size_of::()) }; self.write_bytes(bytes) } - pub fn read_type(&self, value: &mut T) -> IoResult<()> { + /// Reads a value of type `T` from the file. + /// + /// Type T must be plain ol' data. + pub fn read_pod(&self, value: &mut T) -> IoResult<()> { + // SAFETY: Since T is AnyBitPattern, it is safe to cast bytes to T. + unsafe { self.read_type(value) } + } + + /// Reads a value of type `T` from the file. + /// + /// Prefer `read_pod()` when possible, because `read_type()` may cause + /// undefined behavior. + /// + /// # Safety + /// + /// Caller must ensure casting bytes to T is safe. + /// Refer to the Safety sections in std::slice::from_raw_parts() + /// and bytemuck's Pod and AnyBitPattern for more information. + pub unsafe fn read_type(&self, value: &mut T) -> IoResult<()> { let ptr = value as *mut _ as *mut u8; + // SAFETY: The caller ensures it is safe to cast bytes to T, + // we ensure the size is safe by querying T directly, + // and Rust ensures ptr is aligned. let bytes = unsafe { std::slice::from_raw_parts_mut(ptr, mem::size_of::()) }; self.read_bytes(bytes) } diff --git a/accounts-db/src/tiered_storage/footer.rs b/accounts-db/src/tiered_storage/footer.rs index f7b885e1a..d5c8176ec 100644 --- a/accounts-db/src/tiered_storage/footer.rs +++ b/accounts-db/src/tiered_storage/footer.rs @@ -1,7 +1,10 @@ use { crate::tiered_storage::{ - error::TieredStorageError, file::TieredStorageFile, index::IndexBlockFormat, - mmap_utils::get_type, TieredStorageResult, + error::TieredStorageError, + file::TieredStorageFile, + index::IndexBlockFormat, + mmap_utils::{get_pod, get_type}, + TieredStorageResult, }, bytemuck::{Pod, Zeroable}, memmap2::Mmap, @@ -204,8 +207,9 @@ impl TieredStorageFooter { } pub fn write_footer_block(&self, file: &TieredStorageFile) -> TieredStorageResult<()> { - file.write_type(self)?; - file.write_type(&TieredStorageMagicNumber::default())?; + // SAFETY: The footer does not contain any uninitialized bytes. + unsafe { file.write_type(self)? }; + file.write_pod(&TieredStorageMagicNumber::default())?; Ok(()) } @@ -214,13 +218,13 @@ impl TieredStorageFooter { file.seek_from_end(-(FOOTER_TAIL_SIZE as i64))?; let mut footer_version: u64 = 0; - file.read_type(&mut footer_version)?; + file.read_pod(&mut footer_version)?; if footer_version != FOOTER_FORMAT_VERSION { return Err(TieredStorageError::InvalidFooterVersion(footer_version)); } let mut footer_size: u64 = 0; - file.read_type(&mut footer_size)?; + file.read_pod(&mut footer_size)?; if footer_size != FOOTER_SIZE as u64 { return Err(TieredStorageError::InvalidFooterSize( footer_size, @@ -229,7 +233,7 @@ impl TieredStorageFooter { } let mut magic_number = TieredStorageMagicNumber::zeroed(); - file.read_type(&mut magic_number)?; + file.read_pod(&mut magic_number)?; if magic_number != TieredStorageMagicNumber::default() { return Err(TieredStorageError::MagicNumberMismatch( TieredStorageMagicNumber::default().0, @@ -239,7 +243,9 @@ impl TieredStorageFooter { let mut footer = Self::default(); file.seek_from_end(-(footer_size as i64))?; - file.read_type(&mut footer)?; + // SAFETY: We sanitize the footer to ensure all the bytes are + // actually safe to interpret as a TieredStorageFooter. + unsafe { file.read_type(&mut footer)? }; Self::sanitize(&footer)?; Ok(footer) @@ -248,12 +254,12 @@ impl TieredStorageFooter { pub fn new_from_mmap(mmap: &Mmap) -> TieredStorageResult<&TieredStorageFooter> { let offset = mmap.len().saturating_sub(FOOTER_TAIL_SIZE); - let (footer_version, offset) = get_type::(mmap, offset)?; + let (footer_version, offset) = get_pod::(mmap, offset)?; if *footer_version != FOOTER_FORMAT_VERSION { return Err(TieredStorageError::InvalidFooterVersion(*footer_version)); } - let (&footer_size, offset) = get_type::(mmap, offset)?; + let (&footer_size, offset) = get_pod::(mmap, offset)?; if footer_size != FOOTER_SIZE as u64 { return Err(TieredStorageError::InvalidFooterSize( footer_size, @@ -261,7 +267,7 @@ impl TieredStorageFooter { )); } - let (magic_number, _offset) = get_type::(mmap, offset)?; + let (magic_number, _offset) = get_pod::(mmap, offset)?; if *magic_number != TieredStorageMagicNumber::default() { return Err(TieredStorageError::MagicNumberMismatch( TieredStorageMagicNumber::default().0, @@ -270,7 +276,9 @@ impl TieredStorageFooter { } let footer_offset = mmap.len().saturating_sub(footer_size as usize); - let (footer, _offset) = get_type::(mmap, footer_offset)?; + // SAFETY: We sanitize the footer to ensure all the bytes are + // actually safe to interpret as a TieredStorageFooter. + let (footer, _offset) = unsafe { get_type::(mmap, footer_offset)? }; Self::sanitize(footer)?; Ok(footer) diff --git a/accounts-db/src/tiered_storage/hot.rs b/accounts-db/src/tiered_storage/hot.rs index d268266ae..c460d3b9b 100644 --- a/accounts-db/src/tiered_storage/hot.rs +++ b/accounts-db/src/tiered_storage/hot.rs @@ -10,7 +10,7 @@ use { }, index::{AccountOffset, IndexBlockFormat, IndexOffset}, meta::{AccountMetaFlags, AccountMetaOptionalFields, TieredAccountMeta}, - mmap_utils::get_type, + mmap_utils::get_pod, owners::{OwnerOffset, OwnersBlock}, TieredStorageError, TieredStorageFormat, TieredStorageResult, }, @@ -204,7 +204,7 @@ impl TieredAccountMeta for HotAccountMeta { .then(|| { let offset = self.optional_fields_offset(account_block) + AccountMetaOptionalFields::rent_epoch_offset(self.flags()); - byte_block::read_type::(account_block, offset).copied() + byte_block::read_pod::(account_block, offset).copied() }) .flatten() } @@ -217,7 +217,7 @@ impl TieredAccountMeta for HotAccountMeta { .then(|| { let offset = self.optional_fields_offset(account_block) + AccountMetaOptionalFields::account_hash_offset(self.flags()); - byte_block::read_type::(account_block, offset) + byte_block::read_pod::(account_block, offset) }) .flatten() } @@ -283,7 +283,7 @@ impl HotStorageReader { ) -> TieredStorageResult<&HotAccountMeta> { let internal_account_offset = account_offset.offset(); - let (meta, _) = get_type::(&self.mmap, internal_account_offset)?; + let (meta, _) = get_pod::(&self.mmap, internal_account_offset)?; Ok(meta) } @@ -452,13 +452,16 @@ pub mod tests { .with_flags(&flags); let mut writer = ByteBlockWriter::new(AccountBlockFormat::AlignedRaw); - writer.write_type(&expected_meta).unwrap(); - writer.write_type(&account_data).unwrap(); - writer.write_type(&padding).unwrap(); + writer.write_pod(&expected_meta).unwrap(); + // SAFETY: These values are POD, so they are safe to write. + unsafe { + writer.write_type(&account_data).unwrap(); + writer.write_type(&padding).unwrap(); + } writer.write_optional_fields(&optional_fields).unwrap(); let buffer = writer.finish().unwrap(); - let meta = byte_block::read_type::(&buffer, 0).unwrap(); + let meta = byte_block::read_pod::(&buffer, 0).unwrap(); assert_eq!(expected_meta, *meta); assert!(meta.flags().has_rent_epoch()); assert!(meta.flags().has_account_hash()); @@ -548,7 +551,7 @@ pub mod tests { .iter() .map(|meta| { let prev_offset = current_offset; - current_offset += file.write_type(meta).unwrap(); + current_offset += file.write_pod(meta).unwrap(); HotAccountOffset::new(prev_offset).unwrap() }) .collect(); diff --git a/accounts-db/src/tiered_storage/index.rs b/accounts-db/src/tiered_storage/index.rs index 606ee1f90..06bb48491 100644 --- a/accounts-db/src/tiered_storage/index.rs +++ b/accounts-db/src/tiered_storage/index.rs @@ -1,6 +1,6 @@ use { crate::tiered_storage::{ - file::TieredStorageFile, footer::TieredStorageFooter, mmap_utils::get_type, + file::TieredStorageFile, footer::TieredStorageFooter, mmap_utils::get_pod, TieredStorageResult, }, bytemuck::{Pod, Zeroable}, @@ -66,10 +66,10 @@ impl IndexBlockFormat { Self::AddressAndBlockOffsetOnly => { let mut bytes_written = 0; for index_entry in index_entries { - bytes_written += file.write_type(index_entry.address)?; + bytes_written += file.write_pod(index_entry.address)?; } for index_entry in index_entries { - bytes_written += file.write_type(&index_entry.offset)?; + bytes_written += file.write_pod(&index_entry.offset)?; } Ok(bytes_written) } @@ -89,7 +89,7 @@ impl IndexBlockFormat { + std::mem::size_of::() * (index_offset.0 as usize) } }; - let (address, _) = get_type::(mmap, account_offset)?; + let (address, _) = get_pod::(mmap, account_offset)?; Ok(address) } @@ -105,7 +105,7 @@ impl IndexBlockFormat { let offset = footer.index_block_offset as usize + std::mem::size_of::() * footer.account_entry_count as usize + std::mem::size_of::() * index_offset.0 as usize; - let (account_offset, _) = get_type::(mmap, offset)?; + let (account_offset, _) = get_pod::(mmap, offset)?; Ok(*account_offset) } diff --git a/accounts-db/src/tiered_storage/mmap_utils.rs b/accounts-db/src/tiered_storage/mmap_utils.rs index 24d696a1e..610384efd 100644 --- a/accounts-db/src/tiered_storage/mmap_utils.rs +++ b/accounts-db/src/tiered_storage/mmap_utils.rs @@ -5,10 +5,30 @@ use { std::io::Result as IoResult, }; -pub fn get_type(map: &Mmap, offset: usize) -> IoResult<(&T, usize)> { - let (data, next) = get_slice(map, offset, std::mem::size_of::())?; +/// Borrows a value of type `T` from `mmap` +/// +/// Type T must be plain ol' data to ensure no undefined behavior. +pub fn get_pod(mmap: &Mmap, offset: usize) -> IoResult<(&T, usize)> { + // SAFETY: Since T is AnyBitPattern, it is safe to cast bytes to T. + unsafe { get_type::(mmap, offset) } +} + +/// Borrows a value of type `T` from `mmap` +/// +/// Prefer `get_pod()` when possible, because `get_type()` may cause undefined behavior. +/// +/// # Safety +/// +/// Caller must ensure casting bytes to T is safe. +/// Refer to the Safety sections in std::slice::from_raw_parts() +/// and bytemuck's Pod and AnyBitPattern for more information. +pub unsafe fn get_type(mmap: &Mmap, offset: usize) -> IoResult<(&T, usize)> { + let (data, next) = get_slice(mmap, offset, std::mem::size_of::())?; let ptr = data.as_ptr() as *const T; debug_assert!(ptr as usize % std::mem::align_of::() == 0); + // SAFETY: The caller ensures it is safe to cast bytes to T, + // we ensure the size is safe by querying T directly, + // and we just checked above to ensure the ptr is aligned for T. Ok((unsafe { &*ptr }, next)) } @@ -34,5 +54,7 @@ pub fn get_slice(mmap: &Mmap, offset: usize, size: usize) -> IoResult<(&[u8], us let next = u64_align!(next); let ptr = data.as_ptr(); + // SAFETY: The Mmap ensures the bytes are safe the read, and we just checked + // to ensure we don't read past the end of the internal buffer. Ok((unsafe { std::slice::from_raw_parts(ptr, size) }, next)) } diff --git a/accounts-db/src/tiered_storage/owners.rs b/accounts-db/src/tiered_storage/owners.rs index 7cd548e3a..41e1f8a67 100644 --- a/accounts-db/src/tiered_storage/owners.rs +++ b/accounts-db/src/tiered_storage/owners.rs @@ -1,6 +1,6 @@ use { crate::tiered_storage::{ - file::TieredStorageFile, footer::TieredStorageFooter, mmap_utils::get_type, + file::TieredStorageFile, footer::TieredStorageFooter, mmap_utils::get_pod, TieredStorageResult, }, memmap2::Mmap, @@ -32,7 +32,7 @@ impl OwnersBlock { ) -> TieredStorageResult { let mut bytes_written = 0; for address in addresses { - bytes_written += file.write_type(address)?; + bytes_written += file.write_pod(address)?; } Ok(bytes_written) @@ -47,7 +47,7 @@ impl OwnersBlock { ) -> TieredStorageResult<&'a Pubkey> { let offset = footer.owners_block_offset as usize + (std::mem::size_of::() * owner_offset.0 as usize); - let (pubkey, _) = get_type::(mmap, offset)?; + let (pubkey, _) = get_pod::(mmap, offset)?; Ok(pubkey) }