heap_size type to be consistent with request instruction (#33354)

* heap_size type to be consistent with request instruction

* update tests
This commit is contained in:
Tao Zhu 2023-09-25 13:11:26 -05:00 committed by GitHub
parent 08aba38d35
commit 57e78a16dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 24 additions and 34 deletions

View File

@ -102,7 +102,7 @@ pub struct ComputeBudget {
/// The total cost is calculated as `msm_base_cost + (length - 1) * msm_incremental_cost`.
pub curve25519_ristretto_msm_incremental_cost: u64,
/// program heap region size, default: solana_sdk::entrypoint::HEAP_LENGTH
pub heap_size: usize,
pub heap_size: u32,
/// Number of compute units per additional 32k heap above the default (~.5
/// us per 32k at 15 units/us rounded up)
pub heap_cost: u64,
@ -179,7 +179,7 @@ impl ComputeBudget {
curve25519_ristretto_multiply_cost: 2_208,
curve25519_ristretto_msm_base_cost: 2303,
curve25519_ristretto_msm_incremental_cost: 788,
heap_size: solana_sdk::entrypoint::HEAP_LENGTH,
heap_size: u32::try_from(solana_sdk::entrypoint::HEAP_LENGTH).unwrap(),
heap_cost: 8,
mem_op_base_cost: 10,
alt_bn128_addition_cost: 334,
@ -279,7 +279,7 @@ impl ComputeBudget {
InstructionError::InvalidInstructionData,
));
}
self.heap_size = bytes as usize;
self.heap_size = bytes;
}
let compute_unit_limit = updated_compute_unit_limit
@ -524,7 +524,7 @@ mod tests {
Ok(PrioritizationFeeDetails::default()),
ComputeBudget {
compute_unit_limit: DEFAULT_INSTRUCTION_COMPUTE_UNIT_LIMIT as u64,
heap_size: MAX_HEAP_FRAME_BYTES as usize,
heap_size: MAX_HEAP_FRAME_BYTES,
..ComputeBudget::default()
}
);
@ -574,7 +574,7 @@ mod tests {
)),
ComputeBudget {
compute_unit_limit: MAX_COMPUTE_UNIT_LIMIT as u64,
heap_size: MAX_HEAP_FRAME_BYTES as usize,
heap_size: MAX_HEAP_FRAME_BYTES,
..ComputeBudget::default()
}
);
@ -592,7 +592,7 @@ mod tests {
)),
ComputeBudget {
compute_unit_limit: 1,
heap_size: MAX_HEAP_FRAME_BYTES as usize,
heap_size: MAX_HEAP_FRAME_BYTES,
..ComputeBudget::default()
}
);

View File

@ -191,10 +191,10 @@ pub fn check_loader_id(id: &Pubkey) -> bool {
}
/// Only used in macro, do not use directly!
pub fn calculate_heap_cost(heap_size: u64, heap_cost: u64, enable_rounding_fix: bool) -> u64 {
pub fn calculate_heap_cost(heap_size: u32, heap_cost: u64, enable_rounding_fix: bool) -> u64 {
const KIBIBYTE: u64 = 1024;
const PAGE_SIZE_KB: u64 = 32;
let mut rounded_heap_size = heap_size;
let mut rounded_heap_size = u64::from(heap_size);
if enable_rounding_fix {
rounded_heap_size = rounded_heap_size
.saturating_add(PAGE_SIZE_KB.saturating_mul(KIBIBYTE).saturating_sub(1));
@ -267,7 +267,7 @@ macro_rules! create_vm {
.feature_set
.is_active(&solana_sdk::feature_set::round_up_heap_size::id());
let mut heap_cost_result = invoke_context.consume_checked($crate::calculate_heap_cost(
heap_size as u64,
heap_size,
invoke_context.get_compute_budget().heap_cost,
round_up_heap_size,
));
@ -281,7 +281,7 @@ macro_rules! create_vm {
>::zero_filled(stack_size);
let mut heap = solana_rbpf::aligned_memory::AlignedMemory::<
{ solana_rbpf::ebpf::HOST_ALIGN },
>::zero_filled(heap_size);
>::zero_filled(usize::try_from(heap_size).unwrap());
let vm = $crate::create_vm(
$program,
$regions,
@ -4033,40 +4033,31 @@ mod tests {
// when `enable_heap_size_round_up` not enabled:
{
// assert less than 32K heap should cost zero unit
assert_eq!(0, calculate_heap_cost(31_u64 * 1024, heap_cost, false));
assert_eq!(0, calculate_heap_cost(31 * 1024, heap_cost, false));
// assert exact 32K heap should be cost zero unit
assert_eq!(0, calculate_heap_cost(32_u64 * 1024, heap_cost, false));
assert_eq!(0, calculate_heap_cost(32 * 1024, heap_cost, false));
// assert slightly more than 32K heap is mistakenly cost zero unit
assert_eq!(0, calculate_heap_cost(33_u64 * 1024, heap_cost, false));
assert_eq!(0, calculate_heap_cost(33 * 1024, heap_cost, false));
// assert exact 64K heap should cost 1 * heap_cost
assert_eq!(
heap_cost,
calculate_heap_cost(64_u64 * 1024, heap_cost, false)
);
assert_eq!(heap_cost, calculate_heap_cost(64 * 1024, heap_cost, false));
}
// when `enable_heap_size_round_up` is enabled:
{
// assert less than 32K heap should cost zero unit
assert_eq!(0, calculate_heap_cost(31_u64 * 1024, heap_cost, true));
assert_eq!(0, calculate_heap_cost(31 * 1024, heap_cost, true));
// assert exact 32K heap should be cost zero unit
assert_eq!(0, calculate_heap_cost(32_u64 * 1024, heap_cost, true));
assert_eq!(0, calculate_heap_cost(32 * 1024, heap_cost, true));
// assert slightly more than 32K heap should cost 1 * heap_cost
assert_eq!(
heap_cost,
calculate_heap_cost(33_u64 * 1024, heap_cost, true)
);
assert_eq!(heap_cost, calculate_heap_cost(33 * 1024, heap_cost, true));
// assert exact 64K heap should cost 1 * heap_cost
assert_eq!(
heap_cost,
calculate_heap_cost(64_u64 * 1024, heap_cost, true)
);
assert_eq!(heap_cost, calculate_heap_cost(64 * 1024, heap_cost, true));
}
}

View File

@ -94,10 +94,10 @@ pub fn create_program_runtime_environment_v2<'a>(
BuiltinProgram::new_loader(config, FunctionRegistry::default())
}
fn calculate_heap_cost(heap_size: u64, heap_cost: u64) -> u64 {
fn calculate_heap_cost(heap_size: u32, heap_cost: u64) -> u64 {
const KIBIBYTE: u64 = 1024;
const PAGE_SIZE_KB: u64 = 32;
heap_size
u64::from(heap_size)
.saturating_add(PAGE_SIZE_KB.saturating_mul(KIBIBYTE).saturating_sub(1))
.checked_div(PAGE_SIZE_KB.saturating_mul(KIBIBYTE))
.expect("PAGE_SIZE_KB * KIBIBYTE > 0")
@ -114,12 +114,11 @@ pub fn create_vm<'a, 'b>(
let sbpf_version = program.get_sbpf_version();
let compute_budget = invoke_context.get_compute_budget();
let heap_size = compute_budget.heap_size;
invoke_context.consume_checked(calculate_heap_cost(
heap_size as u64,
compute_budget.heap_cost,
))?;
invoke_context.consume_checked(calculate_heap_cost(heap_size, compute_budget.heap_cost))?;
let mut stack = AlignedMemory::<{ ebpf::HOST_ALIGN }>::zero_filled(config.stack_size());
let mut heap = AlignedMemory::<{ ebpf::HOST_ALIGN }>::zero_filled(compute_budget.heap_size);
let mut heap = AlignedMemory::<{ ebpf::HOST_ALIGN }>::zero_filled(
usize::try_from(compute_budget.heap_size).unwrap(),
);
let stack_len = stack.len();
let regions: Vec<MemoryRegion> = vec![
program.get_ro_region(),