direct_mapping: fix iter_memory_pair_chunks in reverse mode (#34204)
iter_memory_pair_chunks was iterating regions in reverse, but not memory _within_ regions in reverse. This commit fixes the issue and simplifies the implementation by removing nested loops which made control flow hard to reason about.
This commit is contained in:
parent
8445246b8f
commit
09088822e7
|
@ -5521,6 +5521,7 @@ dependencies = [
|
|||
"solana-sdk",
|
||||
"solana-zk-token-sdk",
|
||||
"solana_rbpf",
|
||||
"test-case",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@ assert_matches = { workspace = true }
|
|||
memoffset = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
solana-sdk = { workspace = true, features = ["dev-context-only-utils"] }
|
||||
test-case = { workspace = true }
|
||||
|
||||
[lib]
|
||||
crate-type = ["lib"]
|
||||
|
|
|
@ -289,8 +289,8 @@ fn iter_memory_pair_chunks<T, F>(
|
|||
src_access: AccessType,
|
||||
src_addr: u64,
|
||||
dst_access: AccessType,
|
||||
mut dst_addr: u64,
|
||||
n: u64,
|
||||
dst_addr: u64,
|
||||
n_bytes: u64,
|
||||
memory_mapping: &MemoryMapping,
|
||||
reverse: bool,
|
||||
mut fun: F,
|
||||
|
@ -299,52 +299,90 @@ where
|
|||
T: Default,
|
||||
F: FnMut(*const u8, *const u8, usize) -> Result<T, Error>,
|
||||
{
|
||||
let mut src_chunk_iter = MemoryChunkIterator::new(memory_mapping, src_access, src_addr, n)
|
||||
.map_err(EbpfError::from)?;
|
||||
loop {
|
||||
// iterate source chunks
|
||||
let (src_region, src_vm_addr, mut src_len) = match if reverse {
|
||||
src_chunk_iter.next_back()
|
||||
} else {
|
||||
src_chunk_iter.next()
|
||||
} {
|
||||
Some(item) => item?,
|
||||
None => break,
|
||||
};
|
||||
|
||||
let mut src_host_addr = Result::from(src_region.vm_to_host(src_vm_addr, src_len as u64))?;
|
||||
let mut dst_chunk_iter = MemoryChunkIterator::new(memory_mapping, dst_access, dst_addr, n)
|
||||
let mut src_chunk_iter =
|
||||
MemoryChunkIterator::new(memory_mapping, src_access, src_addr, n_bytes)
|
||||
.map_err(EbpfError::from)?;
|
||||
// iterate over destination chunks until this source chunk has been completely copied
|
||||
while src_len > 0 {
|
||||
loop {
|
||||
let (dst_region, dst_vm_addr, dst_len) = match if reverse {
|
||||
dst_chunk_iter.next_back()
|
||||
let mut dst_chunk_iter =
|
||||
MemoryChunkIterator::new(memory_mapping, dst_access, dst_addr, n_bytes)
|
||||
.map_err(EbpfError::from)?;
|
||||
|
||||
let mut src_chunk = None;
|
||||
let mut dst_chunk = None;
|
||||
|
||||
macro_rules! memory_chunk {
|
||||
($chunk_iter:ident, $chunk:ident) => {
|
||||
if let Some($chunk) = &mut $chunk {
|
||||
// Keep processing the current chunk
|
||||
$chunk
|
||||
} else {
|
||||
// This is either the first call or we've processed all the bytes in the current
|
||||
// chunk. Move to the next one.
|
||||
let chunk = match if reverse {
|
||||
$chunk_iter.next_back()
|
||||
} else {
|
||||
dst_chunk_iter.next()
|
||||
$chunk_iter.next()
|
||||
} {
|
||||
Some(item) => item?,
|
||||
None => break,
|
||||
};
|
||||
let dst_host_addr =
|
||||
Result::from(dst_region.vm_to_host(dst_vm_addr, dst_len as u64))?;
|
||||
let chunk_len = src_len.min(dst_len);
|
||||
fun(
|
||||
src_host_addr as *const u8,
|
||||
dst_host_addr as *const u8,
|
||||
chunk_len,
|
||||
)?;
|
||||
src_len = src_len.saturating_sub(chunk_len);
|
||||
if reverse {
|
||||
dst_addr = dst_addr.saturating_sub(chunk_len as u64);
|
||||
} else {
|
||||
dst_addr = dst_addr.saturating_add(chunk_len as u64);
|
||||
}
|
||||
if src_len == 0 {
|
||||
break;
|
||||
}
|
||||
src_host_addr = src_host_addr.saturating_add(chunk_len as u64);
|
||||
$chunk.insert(chunk)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
loop {
|
||||
let (src_region, src_chunk_addr, src_remaining) = memory_chunk!(src_chunk_iter, src_chunk);
|
||||
let (dst_region, dst_chunk_addr, dst_remaining) = memory_chunk!(dst_chunk_iter, dst_chunk);
|
||||
|
||||
// We always process same-length pairs
|
||||
let chunk_len = *src_remaining.min(dst_remaining);
|
||||
|
||||
let (src_host_addr, dst_host_addr) = {
|
||||
let (src_addr, dst_addr) = if reverse {
|
||||
// When scanning backwards not only we want to scan regions from the end,
|
||||
// we want to process the memory within regions backwards as well.
|
||||
(
|
||||
src_chunk_addr
|
||||
.saturating_add(*src_remaining as u64)
|
||||
.saturating_sub(chunk_len as u64),
|
||||
dst_chunk_addr
|
||||
.saturating_add(*dst_remaining as u64)
|
||||
.saturating_sub(chunk_len as u64),
|
||||
)
|
||||
} else {
|
||||
(*src_chunk_addr, *dst_chunk_addr)
|
||||
};
|
||||
|
||||
(
|
||||
Result::from(src_region.vm_to_host(src_addr, chunk_len as u64))?,
|
||||
Result::from(dst_region.vm_to_host(dst_addr, chunk_len as u64))?,
|
||||
)
|
||||
};
|
||||
|
||||
fun(
|
||||
src_host_addr as *const u8,
|
||||
dst_host_addr as *const u8,
|
||||
chunk_len,
|
||||
)?;
|
||||
|
||||
// Update how many bytes we have left to scan in each chunk
|
||||
*src_remaining = src_remaining.saturating_sub(chunk_len);
|
||||
*dst_remaining = dst_remaining.saturating_sub(chunk_len);
|
||||
|
||||
if !reverse {
|
||||
// We've scanned `chunk_len` bytes so we move the vm address forward. In reverse
|
||||
// mode we don't do this since we make progress by decreasing src_len and
|
||||
// dst_len.
|
||||
*src_chunk_addr = src_chunk_addr.saturating_add(chunk_len as u64);
|
||||
*dst_chunk_addr = dst_chunk_addr.saturating_add(chunk_len as u64);
|
||||
}
|
||||
|
||||
if *src_remaining == 0 {
|
||||
src_chunk = None;
|
||||
}
|
||||
|
||||
if *dst_remaining == 0 {
|
||||
dst_chunk = None;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -471,11 +509,13 @@ impl<'a> DoubleEndedIterator for MemoryChunkIterator<'a> {
|
|||
|
||||
#[cfg(test)]
|
||||
#[allow(clippy::indexing_slicing)]
|
||||
#[allow(clippy::arithmetic_side_effects)]
|
||||
mod tests {
|
||||
use {
|
||||
super::*,
|
||||
assert_matches::assert_matches,
|
||||
solana_rbpf::{ebpf::MM_PROGRAM_START, program::SBPFVersion},
|
||||
test_case::test_case,
|
||||
};
|
||||
|
||||
fn to_chunk_vec<'a>(
|
||||
|
@ -734,72 +774,59 @@ mod tests {
|
|||
memmove_non_contiguous(MM_PROGRAM_START, MM_PROGRAM_START + 8, 4, &memory_mapping).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_overlapping_memmove_non_contiguous_right() {
|
||||
#[test_case(&[], (0, 0, 0); "no regions")]
|
||||
#[test_case(&[10], (1, 10, 0); "single region 0 len")]
|
||||
#[test_case(&[10], (0, 5, 5); "single region no overlap")]
|
||||
#[test_case(&[10], (0, 0, 10) ; "single region complete overlap")]
|
||||
#[test_case(&[10], (2, 0, 5); "single region partial overlap start")]
|
||||
#[test_case(&[10], (0, 1, 6); "single region partial overlap middle")]
|
||||
#[test_case(&[10], (2, 5, 5); "single region partial overlap end")]
|
||||
#[test_case(&[3, 5], (0, 5, 2) ; "two regions no overlap, single source region")]
|
||||
#[test_case(&[4, 7], (0, 5, 5) ; "two regions no overlap, multiple source regions")]
|
||||
#[test_case(&[3, 8], (0, 0, 11) ; "two regions complete overlap")]
|
||||
#[test_case(&[2, 9], (3, 0, 5) ; "two regions partial overlap start")]
|
||||
#[test_case(&[3, 9], (1, 2, 5) ; "two regions partial overlap middle")]
|
||||
#[test_case(&[7, 3], (2, 6, 4) ; "two regions partial overlap end")]
|
||||
#[test_case(&[2, 6, 3, 4], (0, 10, 2) ; "many regions no overlap, single source region")]
|
||||
#[test_case(&[2, 1, 2, 5, 6], (2, 10, 4) ; "many regions no overlap, multiple source regions")]
|
||||
#[test_case(&[8, 1, 3, 6], (0, 0, 18) ; "many regions complete overlap")]
|
||||
#[test_case(&[7, 3, 1, 4, 5], (5, 0, 8) ; "many regions overlap start")]
|
||||
#[test_case(&[1, 5, 2, 9, 3], (5, 4, 8) ; "many regions overlap middle")]
|
||||
#[test_case(&[3, 9, 1, 1, 2, 1], (2, 9, 8) ; "many regions overlap end")]
|
||||
fn test_memmove_non_contiguous(
|
||||
regions: &[usize],
|
||||
(src_offset, dst_offset, len): (usize, usize, usize),
|
||||
) {
|
||||
let config = Config {
|
||||
aligned_memory_mapping: false,
|
||||
..Config::default()
|
||||
};
|
||||
let mem1 = vec![0x11; 1];
|
||||
let mut mem2 = vec![0x22; 2];
|
||||
let mut mem3 = vec![0x33; 3];
|
||||
let mut mem4 = vec![0x44; 4];
|
||||
let memory_mapping = MemoryMapping::new(
|
||||
vec![
|
||||
MemoryRegion::new_readonly(&mem1, MM_PROGRAM_START),
|
||||
MemoryRegion::new_writable(&mut mem2, MM_PROGRAM_START + 1),
|
||||
MemoryRegion::new_writable(&mut mem3, MM_PROGRAM_START + 3),
|
||||
MemoryRegion::new_writable(&mut mem4, MM_PROGRAM_START + 6),
|
||||
],
|
||||
&config,
|
||||
&SBPFVersion::V2,
|
||||
)
|
||||
.unwrap();
|
||||
let (mem, memory_mapping) = build_memory_mapping(regions, &config);
|
||||
|
||||
// overlapping memmove right - the implementation will copy backwards
|
||||
assert_eq!(
|
||||
memmove_non_contiguous(MM_PROGRAM_START + 1, MM_PROGRAM_START, 7, &memory_mapping)
|
||||
.unwrap(),
|
||||
0
|
||||
);
|
||||
assert_eq!(&mem1, &[0x11]);
|
||||
assert_eq!(&mem2, &[0x11, 0x22]);
|
||||
assert_eq!(&mem3, &[0x22, 0x33, 0x33]);
|
||||
assert_eq!(&mem4, &[0x33, 0x44, 0x44, 0x44]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_overlapping_memmove_non_contiguous_left() {
|
||||
let config = Config {
|
||||
aligned_memory_mapping: false,
|
||||
..Config::default()
|
||||
// flatten the memory so we can memmove it with ptr::copy
|
||||
let mut expected_memory = flatten_memory(&mem);
|
||||
unsafe {
|
||||
std::ptr::copy(
|
||||
expected_memory.as_ptr().add(src_offset),
|
||||
expected_memory.as_mut_ptr().add(dst_offset),
|
||||
len,
|
||||
)
|
||||
};
|
||||
let mut mem1 = vec![0x11; 1];
|
||||
let mut mem2 = vec![0x22; 2];
|
||||
let mut mem3 = vec![0x33; 3];
|
||||
let mut mem4 = vec![0x44; 4];
|
||||
let memory_mapping = MemoryMapping::new(
|
||||
vec![
|
||||
MemoryRegion::new_writable(&mut mem1, MM_PROGRAM_START),
|
||||
MemoryRegion::new_writable(&mut mem2, MM_PROGRAM_START + 1),
|
||||
MemoryRegion::new_writable(&mut mem3, MM_PROGRAM_START + 3),
|
||||
MemoryRegion::new_writable(&mut mem4, MM_PROGRAM_START + 6),
|
||||
],
|
||||
&config,
|
||||
&SBPFVersion::V2,
|
||||
|
||||
// do our memmove
|
||||
memmove_non_contiguous(
|
||||
MM_PROGRAM_START + dst_offset as u64,
|
||||
MM_PROGRAM_START + src_offset as u64,
|
||||
len as u64,
|
||||
&memory_mapping,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// overlapping memmove left - the implementation will copy forward
|
||||
assert_eq!(
|
||||
memmove_non_contiguous(MM_PROGRAM_START, MM_PROGRAM_START + 1, 7, &memory_mapping)
|
||||
.unwrap(),
|
||||
0
|
||||
);
|
||||
assert_eq!(&mem1, &[0x22]);
|
||||
assert_eq!(&mem2, &[0x22, 0x33]);
|
||||
assert_eq!(&mem3, &[0x33, 0x33, 0x44]);
|
||||
assert_eq!(&mem4, &[0x44, 0x44, 0x44, 0x44]);
|
||||
// flatten memory post our memmove
|
||||
let memory = flatten_memory(&mem);
|
||||
|
||||
// compare libc's memmove with ours
|
||||
assert_eq!(expected_memory, memory);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -910,4 +937,33 @@ mod tests {
|
|||
unsafe { memcmp(b"oobar", b"obarb", 5) }
|
||||
);
|
||||
}
|
||||
|
||||
fn build_memory_mapping<'a>(
|
||||
regions: &[usize],
|
||||
config: &'a Config,
|
||||
) -> (Vec<Vec<u8>>, MemoryMapping<'a>) {
|
||||
let mut regs = vec![];
|
||||
let mut mem = Vec::new();
|
||||
let mut offset = 0;
|
||||
for (i, region_len) in regions.iter().enumerate() {
|
||||
mem.push(
|
||||
(0..*region_len)
|
||||
.map(|x| (i * 10 + x) as u8)
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
regs.push(MemoryRegion::new_writable(
|
||||
&mut mem[i],
|
||||
MM_PROGRAM_START + offset as u64,
|
||||
));
|
||||
offset += *region_len;
|
||||
}
|
||||
|
||||
let memory_mapping = MemoryMapping::new(regs, config, &SBPFVersion::V2).unwrap();
|
||||
|
||||
(mem, memory_mapping)
|
||||
}
|
||||
|
||||
fn flatten_memory(mem: &[Vec<u8>]) -> Vec<u8> {
|
||||
mem.iter().flatten().copied().collect()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue