add curve25519 multiscalar multiplication syscall (#28216)
* add curve25519 multiscalar multiplication syscall * update compute unit costs * update tests * add update to compute budget * add syscall call function * update compute costs in tests * update syscall syntax
This commit is contained in:
parent
061bed0a8c
commit
3f63283eda
|
@ -69,6 +69,12 @@ pub struct ComputeBudget {
|
|||
pub curve25519_edwards_subtract_cost: u64,
|
||||
/// Number of compute units consumed to multiply a curve25519 edwards point
|
||||
pub curve25519_edwards_multiply_cost: u64,
|
||||
/// Number of compute units consumed for a multiscalar multiplication (msm) of edwards points.
|
||||
/// The total cost is calculated as `msm_base_cost + (length - 1) * msm_incremental_cost`.
|
||||
pub curve25519_edwards_msm_base_cost: u64,
|
||||
/// Number of compute units consumed for a multiscalar multiplication (msm) of edwards points.
|
||||
/// The total cost is calculated as `msm_base_cost + (length - 1) * msm_incremental_cost`.
|
||||
pub curve25519_edwards_msm_incremental_cost: u64,
|
||||
/// Number of compute units consumed to validate a curve25519 ristretto point
|
||||
pub curve25519_ristretto_validate_point_cost: u64,
|
||||
/// Number of compute units consumed to add two curve25519 ristretto points
|
||||
|
@ -77,6 +83,12 @@ pub struct ComputeBudget {
|
|||
pub curve25519_ristretto_subtract_cost: u64,
|
||||
/// Number of compute units consumed to multiply a curve25519 ristretto point
|
||||
pub curve25519_ristretto_multiply_cost: u64,
|
||||
/// Number of compute units consumed for a multiscalar multiplication (msm) of ristretto points.
|
||||
/// The total cost is calculated as `msm_base_cost + (length - 1) * msm_incremental_cost`.
|
||||
pub curve25519_ristretto_msm_base_cost: u64,
|
||||
/// Number of compute units consumed for a multiscalar multiplication (msm) of ristretto points.
|
||||
/// The total cost is calculated as `msm_base_cost + (length - 1) * msm_incremental_cost`.
|
||||
pub curve25519_ristretto_msm_incremental_cost: u64,
|
||||
/// Optional program heap region size, if `None` then loader default
|
||||
pub heap_size: Option<usize>,
|
||||
/// Number of compute units per additional 32k heap above the default (~.5
|
||||
|
@ -118,10 +130,14 @@ impl ComputeBudget {
|
|||
curve25519_edwards_add_cost: 331,
|
||||
curve25519_edwards_subtract_cost: 329,
|
||||
curve25519_edwards_multiply_cost: 1_753,
|
||||
curve25519_edwards_msm_base_cost: 1_870,
|
||||
curve25519_edwards_msm_incremental_cost: 670,
|
||||
curve25519_ristretto_validate_point_cost: 117,
|
||||
curve25519_ristretto_add_cost: 367,
|
||||
curve25519_ristretto_subtract_cost: 366,
|
||||
curve25519_ristretto_multiply_cost: 1_804,
|
||||
curve25519_ristretto_msm_base_cost: 1_870,
|
||||
curve25519_ristretto_msm_incremental_cost: 670,
|
||||
heap_size: None,
|
||||
heap_cost: 8,
|
||||
mem_op_base_cost: 10,
|
||||
|
|
|
@ -38,6 +38,13 @@ pub extern "C" fn entrypoint(_input: *mut u8) -> u64 {
|
|||
edwards::multiply_edwards(&scalar_one, &edwards_generator).expect("multiply_edwards")
|
||||
);
|
||||
|
||||
msg!("multiscalar_multiply_edwards");
|
||||
assert_eq!(
|
||||
edwards_generator,
|
||||
edwards::multiscalar_multiply_edwards(&[scalar_one], &[edwards_generator])
|
||||
.expect("multiscalar_multiply_edwards"),
|
||||
);
|
||||
|
||||
let ristretto_identity = ristretto::PodRistrettoPoint([
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0,
|
||||
|
@ -64,6 +71,13 @@ pub extern "C" fn entrypoint(_input: *mut u8) -> u64 {
|
|||
.expect("multiply_ristretto")
|
||||
);
|
||||
|
||||
msg!("multiscalar_multiply_ristretto");
|
||||
assert_eq!(
|
||||
ristretto_generator,
|
||||
ristretto::multiscalar_multiply_ristretto(&[scalar_one], &[ristretto_generator])
|
||||
.expect("multiscalar_multiply_ristretto"),
|
||||
);
|
||||
|
||||
0
|
||||
}
|
||||
|
||||
|
|
|
@ -208,9 +208,7 @@ pub fn register_syscalls(
|
|||
SyscallBlake3::call,
|
||||
)?;
|
||||
|
||||
// Elliptic Curve Point Validation
|
||||
//
|
||||
// TODO: add group operations and multiscalar multiplications
|
||||
// Elliptic Curve Operations
|
||||
register_feature_gated_syscall!(
|
||||
syscall_registry,
|
||||
curve25519_syscall_enabled,
|
||||
|
@ -223,6 +221,12 @@ pub fn register_syscalls(
|
|||
b"sol_curve_group_op",
|
||||
SyscallCurveGroupOps::call,
|
||||
)?;
|
||||
register_feature_gated_syscall!(
|
||||
syscall_registry,
|
||||
curve25519_syscall_enabled,
|
||||
b"sol_curve_multiscalar_mul",
|
||||
SyscallCurveMultiscalarMultiplication::call,
|
||||
)?;
|
||||
|
||||
// Sysvars
|
||||
syscall_registry
|
||||
|
@ -1144,6 +1148,111 @@ declare_syscall!(
|
|||
}
|
||||
);
|
||||
|
||||
declare_syscall!(
|
||||
// Elliptic Curve Multiscalar Multiplication
|
||||
//
|
||||
// Currently, only curve25519 Edwards and Ristretto representations are supported
|
||||
SyscallCurveMultiscalarMultiplication,
|
||||
fn inner_call(
|
||||
invoke_context: &mut InvokeContext,
|
||||
curve_id: u64,
|
||||
scalars_addr: u64,
|
||||
points_addr: u64,
|
||||
points_len: u64,
|
||||
result_point_addr: u64,
|
||||
memory_mapping: &mut MemoryMapping,
|
||||
) -> Result<u64, EbpfError> {
|
||||
use solana_zk_token_sdk::curve25519::{
|
||||
curve_syscall_traits::*, edwards, ristretto, scalar,
|
||||
};
|
||||
match curve_id {
|
||||
CURVE25519_EDWARDS => {
|
||||
let cost = invoke_context
|
||||
.get_compute_budget()
|
||||
.curve25519_edwards_msm_base_cost
|
||||
.saturating_add(
|
||||
invoke_context
|
||||
.get_compute_budget()
|
||||
.curve25519_edwards_msm_incremental_cost
|
||||
.saturating_mul(points_len.saturating_sub(1)),
|
||||
);
|
||||
invoke_context.get_compute_meter().consume(cost)?;
|
||||
|
||||
let scalars = translate_slice::<scalar::PodScalar>(
|
||||
memory_mapping,
|
||||
scalars_addr,
|
||||
points_len,
|
||||
invoke_context.get_check_aligned(),
|
||||
invoke_context.get_check_size(),
|
||||
)?;
|
||||
|
||||
let points = translate_slice::<edwards::PodEdwardsPoint>(
|
||||
memory_mapping,
|
||||
points_addr,
|
||||
points_len,
|
||||
invoke_context.get_check_aligned(),
|
||||
invoke_context.get_check_size(),
|
||||
)?;
|
||||
|
||||
if let Some(result_point) = edwards::multiscalar_multiply_edwards(scalars, points) {
|
||||
*translate_type_mut::<edwards::PodEdwardsPoint>(
|
||||
memory_mapping,
|
||||
result_point_addr,
|
||||
invoke_context.get_check_aligned(),
|
||||
)? = result_point;
|
||||
Ok(0)
|
||||
} else {
|
||||
Ok(1)
|
||||
}
|
||||
}
|
||||
|
||||
CURVE25519_RISTRETTO => {
|
||||
let cost = invoke_context
|
||||
.get_compute_budget()
|
||||
.curve25519_ristretto_msm_base_cost
|
||||
.saturating_add(
|
||||
invoke_context
|
||||
.get_compute_budget()
|
||||
.curve25519_ristretto_msm_incremental_cost
|
||||
.saturating_mul(points_len.saturating_sub(1)),
|
||||
);
|
||||
invoke_context.get_compute_meter().consume(cost)?;
|
||||
|
||||
let scalars = translate_slice::<scalar::PodScalar>(
|
||||
memory_mapping,
|
||||
scalars_addr,
|
||||
points_len,
|
||||
invoke_context.get_check_aligned(),
|
||||
invoke_context.get_check_size(),
|
||||
)?;
|
||||
|
||||
let points = translate_slice::<ristretto::PodRistrettoPoint>(
|
||||
memory_mapping,
|
||||
points_addr,
|
||||
points_len,
|
||||
invoke_context.get_check_aligned(),
|
||||
invoke_context.get_check_size(),
|
||||
)?;
|
||||
|
||||
if let Some(result_point) =
|
||||
ristretto::multiscalar_multiply_ristretto(scalars, points)
|
||||
{
|
||||
*translate_type_mut::<ristretto::PodRistrettoPoint>(
|
||||
memory_mapping,
|
||||
result_point_addr,
|
||||
invoke_context.get_check_aligned(),
|
||||
)? = result_point;
|
||||
Ok(0)
|
||||
} else {
|
||||
Ok(1)
|
||||
}
|
||||
}
|
||||
|
||||
_ => Ok(1),
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
declare_syscall!(
|
||||
// Blake3
|
||||
SyscallBlake3,
|
||||
|
@ -3149,6 +3258,149 @@ mod tests {
|
|||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_syscall_multiscalar_multiplication() {
|
||||
use solana_zk_token_sdk::curve25519::curve_syscall_traits::{
|
||||
CURVE25519_EDWARDS, CURVE25519_RISTRETTO,
|
||||
};
|
||||
|
||||
let config = Config::default();
|
||||
prepare_mockup!(
|
||||
invoke_context,
|
||||
transaction_context,
|
||||
program_id,
|
||||
bpf_loader::id(),
|
||||
);
|
||||
|
||||
let scalar_a: [u8; 32] = [
|
||||
254, 198, 23, 138, 67, 243, 184, 110, 236, 115, 236, 205, 205, 215, 79, 114, 45, 250,
|
||||
78, 137, 3, 107, 136, 237, 49, 126, 117, 223, 37, 191, 88, 6,
|
||||
];
|
||||
let scalar_b: [u8; 32] = [
|
||||
254, 198, 23, 138, 67, 243, 184, 110, 236, 115, 236, 205, 205, 215, 79, 114, 45, 250,
|
||||
78, 137, 3, 107, 136, 237, 49, 126, 117, 223, 37, 191, 88, 6,
|
||||
];
|
||||
|
||||
let scalars = [scalar_a, scalar_b];
|
||||
let scalars_va = 0x100000000;
|
||||
|
||||
let edwards_point_x: [u8; 32] = [
|
||||
252, 31, 230, 46, 173, 95, 144, 148, 158, 157, 63, 10, 8, 68, 58, 176, 142, 192, 168,
|
||||
53, 61, 105, 194, 166, 43, 56, 246, 236, 28, 146, 114, 133,
|
||||
];
|
||||
let edwards_point_y: [u8; 32] = [
|
||||
10, 111, 8, 236, 97, 189, 124, 69, 89, 176, 222, 39, 199, 253, 111, 11, 248, 186, 128,
|
||||
90, 120, 128, 248, 210, 232, 183, 93, 104, 111, 150, 7, 241,
|
||||
];
|
||||
let edwards_points = [edwards_point_x, edwards_point_y];
|
||||
let edwards_points_va = 0x200000000;
|
||||
|
||||
let ristretto_point_x: [u8; 32] = [
|
||||
130, 35, 97, 25, 18, 199, 33, 239, 85, 143, 119, 111, 49, 51, 224, 40, 167, 185, 240,
|
||||
179, 25, 194, 213, 41, 14, 155, 104, 18, 181, 197, 15, 112,
|
||||
];
|
||||
let ristretto_point_y: [u8; 32] = [
|
||||
152, 156, 155, 197, 152, 232, 92, 206, 219, 159, 193, 134, 121, 128, 139, 36, 56, 191,
|
||||
51, 143, 72, 204, 87, 76, 110, 124, 101, 96, 238, 158, 42, 108,
|
||||
];
|
||||
let ristretto_points = [ristretto_point_x, ristretto_point_y];
|
||||
let ristretto_points_va = 0x300000000;
|
||||
|
||||
let result_point: [u8; 32] = [0; 32];
|
||||
let result_point_va = 0x400000000;
|
||||
|
||||
let mut memory_mapping = MemoryMapping::new(
|
||||
vec![
|
||||
MemoryRegion {
|
||||
host_addr: scalars.as_ptr() as *const _ as u64,
|
||||
vm_addr: scalars_va,
|
||||
len: 64,
|
||||
vm_gap_shift: 63,
|
||||
is_writable: false,
|
||||
},
|
||||
MemoryRegion {
|
||||
host_addr: edwards_points.as_ptr() as *const _ as u64,
|
||||
vm_addr: edwards_points_va,
|
||||
len: 64,
|
||||
vm_gap_shift: 63,
|
||||
is_writable: false,
|
||||
},
|
||||
MemoryRegion {
|
||||
host_addr: ristretto_points.as_ptr() as *const _ as u64,
|
||||
vm_addr: ristretto_points_va,
|
||||
len: 64,
|
||||
vm_gap_shift: 63,
|
||||
is_writable: false,
|
||||
},
|
||||
MemoryRegion {
|
||||
host_addr: result_point.as_ptr() as *const _ as u64,
|
||||
vm_addr: result_point_va,
|
||||
len: 32,
|
||||
vm_gap_shift: 63,
|
||||
is_writable: true,
|
||||
},
|
||||
],
|
||||
&config,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
invoke_context
|
||||
.get_compute_meter()
|
||||
.borrow_mut()
|
||||
.mock_set_remaining(
|
||||
invoke_context
|
||||
.get_compute_budget()
|
||||
.curve25519_edwards_msm_base_cost
|
||||
+ invoke_context
|
||||
.get_compute_budget()
|
||||
.curve25519_edwards_msm_incremental_cost
|
||||
+ invoke_context
|
||||
.get_compute_budget()
|
||||
.curve25519_ristretto_msm_base_cost
|
||||
+ invoke_context
|
||||
.get_compute_budget()
|
||||
.curve25519_ristretto_msm_incremental_cost,
|
||||
);
|
||||
|
||||
let mut result = ProgramResult::Ok(0);
|
||||
SyscallCurveMultiscalarMultiplication::call(
|
||||
&mut invoke_context,
|
||||
CURVE25519_EDWARDS,
|
||||
scalars_va,
|
||||
edwards_points_va,
|
||||
2,
|
||||
result_point_va,
|
||||
&mut memory_mapping,
|
||||
&mut result,
|
||||
);
|
||||
|
||||
assert_eq!(0, result.unwrap());
|
||||
let expected_product = [
|
||||
30, 174, 168, 34, 160, 70, 63, 166, 236, 18, 74, 144, 185, 222, 208, 243, 5, 54, 223,
|
||||
172, 185, 75, 244, 26, 70, 18, 248, 46, 207, 184, 235, 60,
|
||||
];
|
||||
assert_eq!(expected_product, result_point);
|
||||
|
||||
let mut result = ProgramResult::Ok(0);
|
||||
SyscallCurveMultiscalarMultiplication::call(
|
||||
&mut invoke_context,
|
||||
CURVE25519_RISTRETTO,
|
||||
scalars_va,
|
||||
ristretto_points_va,
|
||||
2,
|
||||
result_point_va,
|
||||
&mut memory_mapping,
|
||||
&mut result,
|
||||
);
|
||||
|
||||
assert_eq!(0, result.unwrap());
|
||||
let expected_product = [
|
||||
78, 120, 86, 111, 152, 64, 146, 84, 14, 236, 77, 147, 237, 190, 251, 241, 136, 167, 21,
|
||||
94, 84, 118, 92, 140, 120, 81, 30, 246, 173, 140, 195, 86,
|
||||
];
|
||||
assert_eq!(expected_product, result_point);
|
||||
}
|
||||
|
||||
fn create_filled_type<T: Default>(zero_init: bool) -> T {
|
||||
let mut val = T::default();
|
||||
let p = &mut val as *mut _ as *mut u8;
|
||||
|
|
|
@ -62,9 +62,9 @@ define_syscall!(fn sol_log_data(data: *const u8, data_len: u64));
|
|||
define_syscall!(fn sol_get_processed_sibling_instruction(index: u64, meta: *mut ProcessedSiblingInstruction, program_id: *mut Pubkey, data: *mut u8, accounts: *mut AccountMeta) -> u64);
|
||||
define_syscall!(fn sol_get_stack_height() -> u64);
|
||||
define_syscall!(fn sol_set_account_properties(updates_addr: *const AccountPropertyUpdate, updates_count: u64));
|
||||
define_syscall!(fn sol_curve_validate_point(curve_id: u64, point: *const u8, result: *mut u8) -> u64);
|
||||
define_syscall!(fn sol_curve_group_op(curve_id: u64, op_id: u64, left_point: *const u8, right_point: *const u8, result: *mut u8) -> u64);
|
||||
define_syscall!(fn sol_curve_multiscalar_mul(curve_id: u64, scalars: *const u8, points: *const u8, result: *mut u8) -> u64);
|
||||
define_syscall!(fn sol_curve_validate_point(curve_id: u64, point_addr: *const u8, result: *mut u8) -> u64);
|
||||
define_syscall!(fn sol_curve_group_op(curve_id: u64, group_op: u64, left_input_addr: *const u8, right_input_addr: *const u8, result_point_addr: *mut u8) -> u64);
|
||||
define_syscall!(fn sol_curve_multiscalar_mul(curve_id: u64, scalars_addr: *const u8, points_addr: *const u8, points_len: u64, result_point_addr: *mut u8) -> u64);
|
||||
define_syscall!(fn sol_curve_pairing_map(curve_id: u64, point: *const u8, result: *mut u8) -> u64);
|
||||
|
||||
#[cfg(target_feature = "static-syscalls")]
|
||||
|
|
|
@ -212,6 +212,28 @@ mod target_arch {
|
|||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn multiscalar_multiply_edwards(
|
||||
scalars: &[PodScalar],
|
||||
points: &[PodEdwardsPoint],
|
||||
) -> Option<PodEdwardsPoint> {
|
||||
let mut result_point = PodEdwardsPoint::zeroed();
|
||||
let result = unsafe {
|
||||
solana_program::syscalls::sol_curve_multiscalar_mul(
|
||||
CURVE25519_EDWARDS,
|
||||
scalars.as_ptr() as *const u8,
|
||||
points.as_ptr() as *const u8,
|
||||
points.len() as u64,
|
||||
&mut result_point.0 as *mut u8,
|
||||
)
|
||||
};
|
||||
|
||||
if result == 0 {
|
||||
Some(result_point)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -214,6 +214,28 @@ mod target_arch {
|
|||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn multiscalar_multiply_ristretto(
|
||||
scalars: &[PodScalar],
|
||||
points: &[PodRistrettoPoint],
|
||||
) -> Option<PodRistrettoPoint> {
|
||||
let mut result_point = PodRistrettoPoint::zeroed();
|
||||
let result = unsafe {
|
||||
solana_program::syscalls::sol_curve_multiscalar_mul(
|
||||
CURVE25519_RISTRETTO,
|
||||
scalars.as_ptr() as *const u8,
|
||||
points.as_ptr() as *const u8,
|
||||
points.len() as u64,
|
||||
&mut result_point.0 as *mut u8,
|
||||
)
|
||||
};
|
||||
|
||||
if result == 0 {
|
||||
Some(result_point)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
Loading…
Reference in New Issue