direct_mapping: fix iter_memory_pair_chunks in reverse mode (#34204)

iter_memory_pair_chunks was iterating regions in reverse, but not memory
_within_ regions in reverse.
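To see why this matters, here is a minimal sketch (not part of this commit) of an
overlapping copy towards higher addresses: copying front-to-back reads bytes that
have already been overwritten, so a correct memmove has to walk backwards, across
regions and within each region.

// Minimal illustration (hypothetical, not taken from this commit) of why an
// overlapping copy towards higher addresses must be done back-to-front.
fn main() {
    let (src, dst, len) = (0usize, 1usize, 7usize); // copy bytes [0..7] over [1..8]

    // Front-to-back: later source bytes are overwritten before they are read.
    let mut forward = *b"abcdefgh";
    for i in 0..len {
        forward[dst + i] = forward[src + i];
    }
    assert_eq!(&forward, b"aaaaaaaa"); // corrupted

    // Back-to-front: every source byte is read before it can be overwritten.
    let mut backward = *b"abcdefgh";
    for i in (0..len).rev() {
        backward[dst + i] = backward[src + i];
    }
    assert_eq!(&backward, b"aabcdefg"); // the correct memmove result
}

Iterating regions in reverse but walking forward inside each region is the
front-to-back loop in disguise whenever a single region spans part of the overlap.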

This commit fixes the issue and simplifies the implementation by removing
the nested loops, which made the control flow hard to reason about.
Alessandro Decina 2023-11-28 08:46:20 +11:00 committed by GitHub
parent 8445246b8f
commit 09088822e7
3 changed files with 156 additions and 98 deletions

Cargo.lock generated

@@ -5521,6 +5521,7 @@ dependencies = [
"solana-sdk",
"solana-zk-token-sdk",
"solana_rbpf",
"test-case",
"thiserror",
]


@@ -27,6 +27,7 @@ assert_matches = { workspace = true }
memoffset = { workspace = true }
rand = { workspace = true }
solana-sdk = { workspace = true, features = ["dev-context-only-utils"] }
test-case = { workspace = true }
[lib]
crate-type = ["lib"]


@@ -289,8 +289,8 @@ fn iter_memory_pair_chunks<T, F>(
src_access: AccessType,
src_addr: u64,
dst_access: AccessType,
mut dst_addr: u64,
n: u64,
dst_addr: u64,
n_bytes: u64,
memory_mapping: &MemoryMapping,
reverse: bool,
mut fun: F,
@@ -299,52 +299,90 @@ where
T: Default,
F: FnMut(*const u8, *const u8, usize) -> Result<T, Error>,
{
let mut src_chunk_iter = MemoryChunkIterator::new(memory_mapping, src_access, src_addr, n)
.map_err(EbpfError::from)?;
loop {
// iterate source chunks
let (src_region, src_vm_addr, mut src_len) = match if reverse {
src_chunk_iter.next_back()
} else {
src_chunk_iter.next()
} {
Some(item) => item?,
None => break,
};
let mut src_host_addr = Result::from(src_region.vm_to_host(src_vm_addr, src_len as u64))?;
let mut dst_chunk_iter = MemoryChunkIterator::new(memory_mapping, dst_access, dst_addr, n)
let mut src_chunk_iter =
MemoryChunkIterator::new(memory_mapping, src_access, src_addr, n_bytes)
.map_err(EbpfError::from)?;
// iterate over destination chunks until this source chunk has been completely copied
while src_len > 0 {
loop {
let (dst_region, dst_vm_addr, dst_len) = match if reverse {
dst_chunk_iter.next_back()
let mut dst_chunk_iter =
MemoryChunkIterator::new(memory_mapping, dst_access, dst_addr, n_bytes)
.map_err(EbpfError::from)?;
let mut src_chunk = None;
let mut dst_chunk = None;
macro_rules! memory_chunk {
($chunk_iter:ident, $chunk:ident) => {
if let Some($chunk) = &mut $chunk {
// Keep processing the current chunk
$chunk
} else {
// This is either the first call or we've processed all the bytes in the current
// chunk. Move to the next one.
let chunk = match if reverse {
$chunk_iter.next_back()
} else {
dst_chunk_iter.next()
$chunk_iter.next()
} {
Some(item) => item?,
None => break,
};
let dst_host_addr =
Result::from(dst_region.vm_to_host(dst_vm_addr, dst_len as u64))?;
let chunk_len = src_len.min(dst_len);
fun(
src_host_addr as *const u8,
dst_host_addr as *const u8,
chunk_len,
)?;
src_len = src_len.saturating_sub(chunk_len);
if reverse {
dst_addr = dst_addr.saturating_sub(chunk_len as u64);
} else {
dst_addr = dst_addr.saturating_add(chunk_len as u64);
}
if src_len == 0 {
break;
}
src_host_addr = src_host_addr.saturating_add(chunk_len as u64);
$chunk.insert(chunk)
}
};
}
loop {
let (src_region, src_chunk_addr, src_remaining) = memory_chunk!(src_chunk_iter, src_chunk);
let (dst_region, dst_chunk_addr, dst_remaining) = memory_chunk!(dst_chunk_iter, dst_chunk);
// We always process same-length pairs
let chunk_len = *src_remaining.min(dst_remaining);
let (src_host_addr, dst_host_addr) = {
let (src_addr, dst_addr) = if reverse {
// When scanning backwards not only we want to scan regions from the end,
// we want to process the memory within regions backwards as well.
(
src_chunk_addr
.saturating_add(*src_remaining as u64)
.saturating_sub(chunk_len as u64),
dst_chunk_addr
.saturating_add(*dst_remaining as u64)
.saturating_sub(chunk_len as u64),
)
} else {
(*src_chunk_addr, *dst_chunk_addr)
};
(
Result::from(src_region.vm_to_host(src_addr, chunk_len as u64))?,
Result::from(dst_region.vm_to_host(dst_addr, chunk_len as u64))?,
)
};
fun(
src_host_addr as *const u8,
dst_host_addr as *const u8,
chunk_len,
)?;
// Update how many bytes we have left to scan in each chunk
*src_remaining = src_remaining.saturating_sub(chunk_len);
*dst_remaining = dst_remaining.saturating_sub(chunk_len);
if !reverse {
// We've scanned `chunk_len` bytes so we move the vm address forward. In reverse
// mode we don't do this since we make progress by decreasing src_len and
// dst_len.
*src_chunk_addr = src_chunk_addr.saturating_add(chunk_len as u64);
*dst_chunk_addr = dst_chunk_addr.saturating_add(chunk_len as u64);
}
if *src_remaining == 0 {
src_chunk = None;
}
if *dst_remaining == 0 {
dst_chunk = None;
}
}
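For reference, the chunk-pairing control flow above can be modeled by this
standalone sketch (a simplification over plain (offset, len) chunks, not the
commit's code, and assuming no chunk is empty): both cursors advance in lockstep,
each step handles min(src_remaining, dst_remaining) bytes, and in reverse mode
those bytes are taken from the tail of each chunk.

// Hypothetical model of the pairing loop: chunks are (offset, len) ranges into
// one flat buffer, and we return the (src_offset, dst_offset, len) copy steps
// that would be performed. Chunk lengths are assumed to be non-zero.
fn copy_ops(
    src_chunks: &[(usize, usize)],
    dst_chunks: &[(usize, usize)],
    reverse: bool,
) -> Vec<(usize, usize, usize)> {
    let mut src: Vec<(usize, usize)> = src_chunks.to_vec();
    let mut dst: Vec<(usize, usize)> = dst_chunks.to_vec();
    if reverse {
        // Walk the regions from the end...
        src.reverse();
        dst.reverse();
    }
    let (mut si, mut di) = (0, 0);
    let mut ops = Vec::new();
    while si < src.len() && di < dst.len() {
        let (src_off, src_rem) = &mut src[si];
        let (dst_off, dst_rem) = &mut dst[di];
        // Always process same-length pairs.
        let len = (*src_rem).min(*dst_rem);
        let (s, d) = if reverse {
            // ...and take the bytes from the *tail* of each chunk, which is
            // the part the old implementation skipped.
            (*src_off + *src_rem - len, *dst_off + *dst_rem - len)
        } else {
            (*src_off, *dst_off)
        };
        ops.push((s, d, len));
        *src_rem -= len;
        *dst_rem -= len;
        if !reverse {
            *src_off += len;
            *dst_off += len;
        }
        if *src_rem == 0 {
            si += 1;
        }
        if *dst_rem == 0 {
            di += 1;
        }
    }
    ops
}

For example, with src_chunks = [(0, 3), (3, 4)] and dst_chunks = [(1, 7)], the
forward call yields [(0, 1, 3), (3, 4, 4)] while the reverse call yields the same
ranges back-to-front, [(3, 4, 4), (0, 1, 3)], which is what keeps an overlapping
copy towards higher addresses correct.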
@@ -471,11 +509,13 @@ impl<'a> DoubleEndedIterator for MemoryChunkIterator<'a> {
#[cfg(test)]
#[allow(clippy::indexing_slicing)]
#[allow(clippy::arithmetic_side_effects)]
mod tests {
use {
super::*,
assert_matches::assert_matches,
solana_rbpf::{ebpf::MM_PROGRAM_START, program::SBPFVersion},
test_case::test_case,
};
fn to_chunk_vec<'a>(
@@ -734,72 +774,59 @@ mod tests {
memmove_non_contiguous(MM_PROGRAM_START, MM_PROGRAM_START + 8, 4, &memory_mapping).unwrap();
}
#[test]
fn test_overlapping_memmove_non_contiguous_right() {
#[test_case(&[], (0, 0, 0); "no regions")]
#[test_case(&[10], (1, 10, 0); "single region 0 len")]
#[test_case(&[10], (0, 5, 5); "single region no overlap")]
#[test_case(&[10], (0, 0, 10) ; "single region complete overlap")]
#[test_case(&[10], (2, 0, 5); "single region partial overlap start")]
#[test_case(&[10], (0, 1, 6); "single region partial overlap middle")]
#[test_case(&[10], (2, 5, 5); "single region partial overlap end")]
#[test_case(&[3, 5], (0, 5, 2) ; "two regions no overlap, single source region")]
#[test_case(&[4, 7], (0, 5, 5) ; "two regions no overlap, multiple source regions")]
#[test_case(&[3, 8], (0, 0, 11) ; "two regions complete overlap")]
#[test_case(&[2, 9], (3, 0, 5) ; "two regions partial overlap start")]
#[test_case(&[3, 9], (1, 2, 5) ; "two regions partial overlap middle")]
#[test_case(&[7, 3], (2, 6, 4) ; "two regions partial overlap end")]
#[test_case(&[2, 6, 3, 4], (0, 10, 2) ; "many regions no overlap, single source region")]
#[test_case(&[2, 1, 2, 5, 6], (2, 10, 4) ; "many regions no overlap, multiple source regions")]
#[test_case(&[8, 1, 3, 6], (0, 0, 18) ; "many regions complete overlap")]
#[test_case(&[7, 3, 1, 4, 5], (5, 0, 8) ; "many regions overlap start")]
#[test_case(&[1, 5, 2, 9, 3], (5, 4, 8) ; "many regions overlap middle")]
#[test_case(&[3, 9, 1, 1, 2, 1], (2, 9, 8) ; "many regions overlap end")]
fn test_memmove_non_contiguous(
regions: &[usize],
(src_offset, dst_offset, len): (usize, usize, usize),
) {
let config = Config {
aligned_memory_mapping: false,
..Config::default()
};
let mem1 = vec![0x11; 1];
let mut mem2 = vec![0x22; 2];
let mut mem3 = vec![0x33; 3];
let mut mem4 = vec![0x44; 4];
let memory_mapping = MemoryMapping::new(
vec![
MemoryRegion::new_readonly(&mem1, MM_PROGRAM_START),
MemoryRegion::new_writable(&mut mem2, MM_PROGRAM_START + 1),
MemoryRegion::new_writable(&mut mem3, MM_PROGRAM_START + 3),
MemoryRegion::new_writable(&mut mem4, MM_PROGRAM_START + 6),
],
&config,
&SBPFVersion::V2,
)
.unwrap();
let (mem, memory_mapping) = build_memory_mapping(regions, &config);
// overlapping memmove right - the implementation will copy backwards
assert_eq!(
memmove_non_contiguous(MM_PROGRAM_START + 1, MM_PROGRAM_START, 7, &memory_mapping)
.unwrap(),
0
);
assert_eq!(&mem1, &[0x11]);
assert_eq!(&mem2, &[0x11, 0x22]);
assert_eq!(&mem3, &[0x22, 0x33, 0x33]);
assert_eq!(&mem4, &[0x33, 0x44, 0x44, 0x44]);
}
#[test]
fn test_overlapping_memmove_non_contiguous_left() {
let config = Config {
aligned_memory_mapping: false,
..Config::default()
// flatten the memory so we can memmove it with ptr::copy
let mut expected_memory = flatten_memory(&mem);
unsafe {
std::ptr::copy(
expected_memory.as_ptr().add(src_offset),
expected_memory.as_mut_ptr().add(dst_offset),
len,
)
};
let mut mem1 = vec![0x11; 1];
let mut mem2 = vec![0x22; 2];
let mut mem3 = vec![0x33; 3];
let mut mem4 = vec![0x44; 4];
let memory_mapping = MemoryMapping::new(
vec![
MemoryRegion::new_writable(&mut mem1, MM_PROGRAM_START),
MemoryRegion::new_writable(&mut mem2, MM_PROGRAM_START + 1),
MemoryRegion::new_writable(&mut mem3, MM_PROGRAM_START + 3),
MemoryRegion::new_writable(&mut mem4, MM_PROGRAM_START + 6),
],
&config,
&SBPFVersion::V2,
// do our memmove
memmove_non_contiguous(
MM_PROGRAM_START + dst_offset as u64,
MM_PROGRAM_START + src_offset as u64,
len as u64,
&memory_mapping,
)
.unwrap();
// overlapping memmove left - the implementation will copy forward
assert_eq!(
memmove_non_contiguous(MM_PROGRAM_START, MM_PROGRAM_START + 1, 7, &memory_mapping)
.unwrap(),
0
);
assert_eq!(&mem1, &[0x22]);
assert_eq!(&mem2, &[0x22, 0x33]);
assert_eq!(&mem3, &[0x33, 0x33, 0x44]);
assert_eq!(&mem4, &[0x44, 0x44, 0x44, 0x44]);
// flatten memory post our memmove
let memory = flatten_memory(&mem);
// compare libc's memmove with ours
assert_eq!(expected_memory, memory);
}
#[test]
@@ -910,4 +937,33 @@ mod tests {
unsafe { memcmp(b"oobar", b"obarb", 5) }
);
}
fn build_memory_mapping<'a>(
regions: &[usize],
config: &'a Config,
) -> (Vec<Vec<u8>>, MemoryMapping<'a>) {
let mut regs = vec![];
let mut mem = Vec::new();
let mut offset = 0;
for (i, region_len) in regions.iter().enumerate() {
mem.push(
(0..*region_len)
.map(|x| (i * 10 + x) as u8)
.collect::<Vec<_>>(),
);
regs.push(MemoryRegion::new_writable(
&mut mem[i],
MM_PROGRAM_START + offset as u64,
));
offset += *region_len;
}
let memory_mapping = MemoryMapping::new(regs, config, &SBPFVersion::V2).unwrap();
(mem, memory_mapping)
}
fn flatten_memory(mem: &[Vec<u8>]) -> Vec<u8> {
mem.iter().flatten().copied().collect()
}
}