add curve25519 multiscalar multiplication syscall (#28216)

* add curve25519 multiscalar multiplication syscall

* update compute unit costs

* update tests

* add update to compute budget

* add syscall call function

* update compute costs in tests

* update syscall syntax
This commit is contained in:
samkim-crypto 2022-10-12 14:43:02 +09:00 committed by GitHub
parent 061bed0a8c
commit 3f63283eda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 332 additions and 6 deletions

View File

@ -69,6 +69,12 @@ pub struct ComputeBudget {
pub curve25519_edwards_subtract_cost: u64,
/// Number of compute units consumed to multiply a curve25519 edwards point
pub curve25519_edwards_multiply_cost: u64,
/// Number of compute units consumed for a multiscalar multiplication (msm) of edwards points.
/// The total cost is calculated as `msm_base_cost + (length - 1) * msm_incremental_cost`.
pub curve25519_edwards_msm_base_cost: u64,
/// Number of compute units consumed for a multiscalar multiplication (msm) of edwards points.
/// The total cost is calculated as `msm_base_cost + (length - 1) * msm_incremental_cost`.
pub curve25519_edwards_msm_incremental_cost: u64,
/// Number of compute units consumed to validate a curve25519 ristretto point
pub curve25519_ristretto_validate_point_cost: u64,
/// Number of compute units consumed to add two curve25519 ristretto points
@ -77,6 +83,12 @@ pub struct ComputeBudget {
pub curve25519_ristretto_subtract_cost: u64,
/// Number of compute units consumed to multiply a curve25519 ristretto point
pub curve25519_ristretto_multiply_cost: u64,
/// Number of compute units consumed for a multiscalar multiplication (msm) of ristretto points.
/// The total cost is calculated as `msm_base_cost + (length - 1) * msm_incremental_cost`.
pub curve25519_ristretto_msm_base_cost: u64,
/// Number of compute units consumed for a multiscalar multiplication (msm) of ristretto points.
/// The total cost is calculated as `msm_base_cost + (length - 1) * msm_incremental_cost`.
pub curve25519_ristretto_msm_incremental_cost: u64,
/// Optional program heap region size, if `None` then loader default
pub heap_size: Option<usize>,
/// Number of compute units per additional 32k heap above the default (~.5
@ -118,10 +130,14 @@ impl ComputeBudget {
curve25519_edwards_add_cost: 331,
curve25519_edwards_subtract_cost: 329,
curve25519_edwards_multiply_cost: 1_753,
curve25519_edwards_msm_base_cost: 1_870,
curve25519_edwards_msm_incremental_cost: 670,
curve25519_ristretto_validate_point_cost: 117,
curve25519_ristretto_add_cost: 367,
curve25519_ristretto_subtract_cost: 366,
curve25519_ristretto_multiply_cost: 1_804,
curve25519_ristretto_msm_base_cost: 1_870,
curve25519_ristretto_msm_incremental_cost: 670,
heap_size: None,
heap_cost: 8,
mem_op_base_cost: 10,

View File

@ -38,6 +38,13 @@ pub extern "C" fn entrypoint(_input: *mut u8) -> u64 {
edwards::multiply_edwards(&scalar_one, &edwards_generator).expect("multiply_edwards")
);
msg!("multiscalar_multiply_edwards");
assert_eq!(
edwards_generator,
edwards::multiscalar_multiply_edwards(&[scalar_one], &[edwards_generator])
.expect("multiscalar_multiply_edwards"),
);
let ristretto_identity = ristretto::PodRistrettoPoint([
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0,
@ -64,6 +71,13 @@ pub extern "C" fn entrypoint(_input: *mut u8) -> u64 {
.expect("multiply_ristretto")
);
msg!("multiscalar_multiply_ristretto");
assert_eq!(
ristretto_generator,
ristretto::multiscalar_multiply_ristretto(&[scalar_one], &[ristretto_generator])
.expect("multiscalar_multiply_ristretto"),
);
0
}

View File

@ -208,9 +208,7 @@ pub fn register_syscalls(
SyscallBlake3::call,
)?;
// Elliptic Curve Point Validation
//
// TODO: add group operations and multiscalar multiplications
// Elliptic Curve Operations
register_feature_gated_syscall!(
syscall_registry,
curve25519_syscall_enabled,
@ -223,6 +221,12 @@ pub fn register_syscalls(
b"sol_curve_group_op",
SyscallCurveGroupOps::call,
)?;
register_feature_gated_syscall!(
syscall_registry,
curve25519_syscall_enabled,
b"sol_curve_multiscalar_mul",
SyscallCurveMultiscalarMultiplication::call,
)?;
// Sysvars
syscall_registry
@ -1144,6 +1148,111 @@ declare_syscall!(
}
);
declare_syscall!(
// Elliptic Curve Multiscalar Multiplication
//
// Currently, only curve25519 Edwards and Ristretto representations are supported
SyscallCurveMultiscalarMultiplication,
fn inner_call(
invoke_context: &mut InvokeContext,
curve_id: u64,
scalars_addr: u64,
points_addr: u64,
points_len: u64,
result_point_addr: u64,
memory_mapping: &mut MemoryMapping,
) -> Result<u64, EbpfError> {
use solana_zk_token_sdk::curve25519::{
curve_syscall_traits::*, edwards, ristretto, scalar,
};
match curve_id {
CURVE25519_EDWARDS => {
let cost = invoke_context
.get_compute_budget()
.curve25519_edwards_msm_base_cost
.saturating_add(
invoke_context
.get_compute_budget()
.curve25519_edwards_msm_incremental_cost
.saturating_mul(points_len.saturating_sub(1)),
);
invoke_context.get_compute_meter().consume(cost)?;
let scalars = translate_slice::<scalar::PodScalar>(
memory_mapping,
scalars_addr,
points_len,
invoke_context.get_check_aligned(),
invoke_context.get_check_size(),
)?;
let points = translate_slice::<edwards::PodEdwardsPoint>(
memory_mapping,
points_addr,
points_len,
invoke_context.get_check_aligned(),
invoke_context.get_check_size(),
)?;
if let Some(result_point) = edwards::multiscalar_multiply_edwards(scalars, points) {
*translate_type_mut::<edwards::PodEdwardsPoint>(
memory_mapping,
result_point_addr,
invoke_context.get_check_aligned(),
)? = result_point;
Ok(0)
} else {
Ok(1)
}
}
CURVE25519_RISTRETTO => {
let cost = invoke_context
.get_compute_budget()
.curve25519_ristretto_msm_base_cost
.saturating_add(
invoke_context
.get_compute_budget()
.curve25519_ristretto_msm_incremental_cost
.saturating_mul(points_len.saturating_sub(1)),
);
invoke_context.get_compute_meter().consume(cost)?;
let scalars = translate_slice::<scalar::PodScalar>(
memory_mapping,
scalars_addr,
points_len,
invoke_context.get_check_aligned(),
invoke_context.get_check_size(),
)?;
let points = translate_slice::<ristretto::PodRistrettoPoint>(
memory_mapping,
points_addr,
points_len,
invoke_context.get_check_aligned(),
invoke_context.get_check_size(),
)?;
if let Some(result_point) =
ristretto::multiscalar_multiply_ristretto(scalars, points)
{
*translate_type_mut::<ristretto::PodRistrettoPoint>(
memory_mapping,
result_point_addr,
invoke_context.get_check_aligned(),
)? = result_point;
Ok(0)
} else {
Ok(1)
}
}
_ => Ok(1),
}
}
);
declare_syscall!(
// Blake3
SyscallBlake3,
@ -3149,6 +3258,149 @@ mod tests {
));
}
#[test]
fn test_syscall_multiscalar_multiplication() {
use solana_zk_token_sdk::curve25519::curve_syscall_traits::{
CURVE25519_EDWARDS, CURVE25519_RISTRETTO,
};
let config = Config::default();
prepare_mockup!(
invoke_context,
transaction_context,
program_id,
bpf_loader::id(),
);
let scalar_a: [u8; 32] = [
254, 198, 23, 138, 67, 243, 184, 110, 236, 115, 236, 205, 205, 215, 79, 114, 45, 250,
78, 137, 3, 107, 136, 237, 49, 126, 117, 223, 37, 191, 88, 6,
];
let scalar_b: [u8; 32] = [
254, 198, 23, 138, 67, 243, 184, 110, 236, 115, 236, 205, 205, 215, 79, 114, 45, 250,
78, 137, 3, 107, 136, 237, 49, 126, 117, 223, 37, 191, 88, 6,
];
let scalars = [scalar_a, scalar_b];
let scalars_va = 0x100000000;
let edwards_point_x: [u8; 32] = [
252, 31, 230, 46, 173, 95, 144, 148, 158, 157, 63, 10, 8, 68, 58, 176, 142, 192, 168,
53, 61, 105, 194, 166, 43, 56, 246, 236, 28, 146, 114, 133,
];
let edwards_point_y: [u8; 32] = [
10, 111, 8, 236, 97, 189, 124, 69, 89, 176, 222, 39, 199, 253, 111, 11, 248, 186, 128,
90, 120, 128, 248, 210, 232, 183, 93, 104, 111, 150, 7, 241,
];
let edwards_points = [edwards_point_x, edwards_point_y];
let edwards_points_va = 0x200000000;
let ristretto_point_x: [u8; 32] = [
130, 35, 97, 25, 18, 199, 33, 239, 85, 143, 119, 111, 49, 51, 224, 40, 167, 185, 240,
179, 25, 194, 213, 41, 14, 155, 104, 18, 181, 197, 15, 112,
];
let ristretto_point_y: [u8; 32] = [
152, 156, 155, 197, 152, 232, 92, 206, 219, 159, 193, 134, 121, 128, 139, 36, 56, 191,
51, 143, 72, 204, 87, 76, 110, 124, 101, 96, 238, 158, 42, 108,
];
let ristretto_points = [ristretto_point_x, ristretto_point_y];
let ristretto_points_va = 0x300000000;
let result_point: [u8; 32] = [0; 32];
let result_point_va = 0x400000000;
let mut memory_mapping = MemoryMapping::new(
vec![
MemoryRegion {
host_addr: scalars.as_ptr() as *const _ as u64,
vm_addr: scalars_va,
len: 64,
vm_gap_shift: 63,
is_writable: false,
},
MemoryRegion {
host_addr: edwards_points.as_ptr() as *const _ as u64,
vm_addr: edwards_points_va,
len: 64,
vm_gap_shift: 63,
is_writable: false,
},
MemoryRegion {
host_addr: ristretto_points.as_ptr() as *const _ as u64,
vm_addr: ristretto_points_va,
len: 64,
vm_gap_shift: 63,
is_writable: false,
},
MemoryRegion {
host_addr: result_point.as_ptr() as *const _ as u64,
vm_addr: result_point_va,
len: 32,
vm_gap_shift: 63,
is_writable: true,
},
],
&config,
)
.unwrap();
invoke_context
.get_compute_meter()
.borrow_mut()
.mock_set_remaining(
invoke_context
.get_compute_budget()
.curve25519_edwards_msm_base_cost
+ invoke_context
.get_compute_budget()
.curve25519_edwards_msm_incremental_cost
+ invoke_context
.get_compute_budget()
.curve25519_ristretto_msm_base_cost
+ invoke_context
.get_compute_budget()
.curve25519_ristretto_msm_incremental_cost,
);
let mut result = ProgramResult::Ok(0);
SyscallCurveMultiscalarMultiplication::call(
&mut invoke_context,
CURVE25519_EDWARDS,
scalars_va,
edwards_points_va,
2,
result_point_va,
&mut memory_mapping,
&mut result,
);
assert_eq!(0, result.unwrap());
let expected_product = [
30, 174, 168, 34, 160, 70, 63, 166, 236, 18, 74, 144, 185, 222, 208, 243, 5, 54, 223,
172, 185, 75, 244, 26, 70, 18, 248, 46, 207, 184, 235, 60,
];
assert_eq!(expected_product, result_point);
let mut result = ProgramResult::Ok(0);
SyscallCurveMultiscalarMultiplication::call(
&mut invoke_context,
CURVE25519_RISTRETTO,
scalars_va,
ristretto_points_va,
2,
result_point_va,
&mut memory_mapping,
&mut result,
);
assert_eq!(0, result.unwrap());
let expected_product = [
78, 120, 86, 111, 152, 64, 146, 84, 14, 236, 77, 147, 237, 190, 251, 241, 136, 167, 21,
94, 84, 118, 92, 140, 120, 81, 30, 246, 173, 140, 195, 86,
];
assert_eq!(expected_product, result_point);
}
fn create_filled_type<T: Default>(zero_init: bool) -> T {
let mut val = T::default();
let p = &mut val as *mut _ as *mut u8;

View File

@ -62,9 +62,9 @@ define_syscall!(fn sol_log_data(data: *const u8, data_len: u64));
define_syscall!(fn sol_get_processed_sibling_instruction(index: u64, meta: *mut ProcessedSiblingInstruction, program_id: *mut Pubkey, data: *mut u8, accounts: *mut AccountMeta) -> u64);
define_syscall!(fn sol_get_stack_height() -> u64);
define_syscall!(fn sol_set_account_properties(updates_addr: *const AccountPropertyUpdate, updates_count: u64));
define_syscall!(fn sol_curve_validate_point(curve_id: u64, point: *const u8, result: *mut u8) -> u64);
define_syscall!(fn sol_curve_group_op(curve_id: u64, op_id: u64, left_point: *const u8, right_point: *const u8, result: *mut u8) -> u64);
define_syscall!(fn sol_curve_multiscalar_mul(curve_id: u64, scalars: *const u8, points: *const u8, result: *mut u8) -> u64);
define_syscall!(fn sol_curve_validate_point(curve_id: u64, point_addr: *const u8, result: *mut u8) -> u64);
define_syscall!(fn sol_curve_group_op(curve_id: u64, group_op: u64, left_input_addr: *const u8, right_input_addr: *const u8, result_point_addr: *mut u8) -> u64);
define_syscall!(fn sol_curve_multiscalar_mul(curve_id: u64, scalars_addr: *const u8, points_addr: *const u8, points_len: u64, result_point_addr: *mut u8) -> u64);
define_syscall!(fn sol_curve_pairing_map(curve_id: u64, point: *const u8, result: *mut u8) -> u64);
#[cfg(target_feature = "static-syscalls")]

View File

@ -212,6 +212,28 @@ mod target_arch {
None
}
}
pub fn multiscalar_multiply_edwards(
scalars: &[PodScalar],
points: &[PodEdwardsPoint],
) -> Option<PodEdwardsPoint> {
let mut result_point = PodEdwardsPoint::zeroed();
let result = unsafe {
solana_program::syscalls::sol_curve_multiscalar_mul(
CURVE25519_EDWARDS,
scalars.as_ptr() as *const u8,
points.as_ptr() as *const u8,
points.len() as u64,
&mut result_point.0 as *mut u8,
)
};
if result == 0 {
Some(result_point)
} else {
None
}
}
}
#[cfg(test)]

View File

@ -214,6 +214,28 @@ mod target_arch {
None
}
}
pub fn multiscalar_multiply_ristretto(
scalars: &[PodScalar],
points: &[PodRistrettoPoint],
) -> Option<PodRistrettoPoint> {
let mut result_point = PodRistrettoPoint::zeroed();
let result = unsafe {
solana_program::syscalls::sol_curve_multiscalar_mul(
CURVE25519_RISTRETTO,
scalars.as_ptr() as *const u8,
points.as_ptr() as *const u8,
points.len() as u64,
&mut result_point.0 as *mut u8,
)
};
if result == 0 {
Some(result_point)
} else {
None
}
}
}
#[cfg(test)]