diff --git a/program-runtime/src/compute_budget.rs b/program-runtime/src/compute_budget.rs index 42cef21966..9dc1e09672 100644 --- a/program-runtime/src/compute_budget.rs +++ b/program-runtime/src/compute_budget.rs @@ -57,10 +57,22 @@ pub struct ComputeBudget { pub syscall_base_cost: u64, /// Number of compute units consumed to call zktoken_crypto_op pub zk_token_elgamal_op_cost: u64, // to be replaced by curve25519 operations - /// Number of compute units consumed to add/sub two edwards points + /// Number of compute units consumed to validate a curve25519 edwards point pub curve25519_edwards_validate_point_cost: u64, - /// Number of compute units consumed to add/sub two ristretto points + /// Number of compute units consumed to add two curve25519 edwards points + pub curve25519_edwards_add_cost: u64, + /// Number of compute units consumed to subtract two curve25519 edwards points + pub curve25519_edwards_subtract_cost: u64, + /// Number of compute units consumed to multiply a curve25519 edwards point + pub curve25519_edwards_multiply_cost: u64, + /// Number of compute units consumed to validate a curve25519 ristretto point pub curve25519_ristretto_validate_point_cost: u64, + /// Number of compute units consumed to add two curve25519 ristretto points + pub curve25519_ristretto_add_cost: u64, + /// Number of compute units consumed to subtract two curve25519 ristretto points + pub curve25519_ristretto_subtract_cost: u64, + /// Number of compute units consumed to multiply a curve25519 ristretto point + pub curve25519_ristretto_multiply_cost: u64, /// Optional program heap region size, if `None` then loader default pub heap_size: Option, /// Number of compute units per additional 32k heap above the default (~.5 @@ -96,8 +108,14 @@ impl ComputeBudget { secp256k1_recover_cost: 25_000, syscall_base_cost: 100, zk_token_elgamal_op_cost: 25_000, - curve25519_edwards_validate_point_cost: 25_000, // TODO: precisely determine cost + curve25519_edwards_validate_point_cost: 25_000, // TODO: precisely determine curve25519 costs + curve25519_edwards_add_cost: 25_000, + curve25519_edwards_subtract_cost: 25_000, + curve25519_edwards_multiply_cost: 25_000, curve25519_ristretto_validate_point_cost: 25_000, + curve25519_ristretto_add_cost: 25_000, + curve25519_ristretto_subtract_cost: 25_000, + curve25519_ristretto_multiply_cost: 25_000, heap_size: None, heap_cost: 8, mem_op_base_cost: 10, diff --git a/programs/bpf_loader/src/syscalls.rs b/programs/bpf_loader/src/syscalls.rs index 9cfc53732b..660c1ffbf3 100644 --- a/programs/bpf_loader/src/syscalls.rs +++ b/programs/bpf_loader/src/syscalls.rs @@ -261,6 +261,13 @@ pub fn register_syscalls( SyscallCurvePointValidation::init, SyscallCurvePointValidation::call, )?; + register_feature_gated_syscall!( + syscall_registry, + curve25519_syscall_enabled, + b"sol_curve25519_point_validation", + SyscallCurveGroupOps::init, + SyscallCurveGroupOps::call, + )?; // Sysvars syscall_registry.register_syscall_by_name( @@ -1979,6 +1986,264 @@ declare_syscall!( } ); +declare_syscall!( + // Elliptic Curve Group Operations + // + // Currently, only curve25519 Edwards and Ristretto representations are supported + SyscallCurveGroupOps, + fn call( + &mut self, + curve_id: u64, + group_op: u64, + left_input_addr: u64, + right_input_addr: u64, + result_point_addr: u64, + memory_mapping: &MemoryMapping, + result: &mut Result>, + ) { + use solana_zk_token_sdk::curve25519::{ + curve_syscall_traits::*, edwards, ristretto, scalar, + }; + + let invoke_context = question_mark!( + self.invoke_context + .try_borrow() + .map_err(|_| SyscallError::InvokeContextBorrowFailed), + result + ); + + match curve_id { + CURVE25519_EDWARDS => match group_op { + ADD => { + let cost = invoke_context + .get_compute_budget() + .curve25519_edwards_add_cost; + question_mark!(invoke_context.get_compute_meter().consume(cost), result); + + let left_point = question_mark!( + translate_type::( + memory_mapping, + left_input_addr, + invoke_context.get_check_aligned(), + ), + result + ); + let right_point = question_mark!( + translate_type::( + memory_mapping, + right_input_addr, + invoke_context.get_check_aligned(), + ), + result + ); + + if let Some(result_point) = edwards::add_edwards(left_point, right_point) { + *question_mark!( + translate_type_mut::( + memory_mapping, + result_point_addr, + invoke_context.get_check_aligned(), + ), + result + ) = result_point; + *result = Ok(0); + } + } + SUB => { + let cost = invoke_context + .get_compute_budget() + .curve25519_edwards_subtract_cost; + question_mark!(invoke_context.get_compute_meter().consume(cost), result); + + let left_point = question_mark!( + translate_type::( + memory_mapping, + left_input_addr, + invoke_context.get_check_aligned(), + ), + result + ); + let right_point = question_mark!( + translate_type::( + memory_mapping, + right_input_addr, + invoke_context.get_check_aligned(), + ), + result + ); + + if let Some(result_point) = edwards::subtract_edwards(left_point, right_point) { + *question_mark!( + translate_type_mut::( + memory_mapping, + result_point_addr, + invoke_context.get_check_aligned(), + ), + result + ) = result_point; + *result = Ok(0); + } + } + MUL => { + let cost = invoke_context + .get_compute_budget() + .curve25519_edwards_multiply_cost; + question_mark!(invoke_context.get_compute_meter().consume(cost), result); + + let scalar = question_mark!( + translate_type::( + memory_mapping, + left_input_addr, + invoke_context.get_check_aligned(), + ), + result + ); + let input_point = question_mark!( + translate_type::( + memory_mapping, + right_input_addr, + invoke_context.get_check_aligned(), + ), + result + ); + + if let Some(result_point) = edwards::multiply_edwards(scalar, input_point) { + *question_mark!( + translate_type_mut::( + memory_mapping, + result_point_addr, + invoke_context.get_check_aligned(), + ), + result + ) = result_point; + *result = Ok(0); + } + } + _ => { + *result = Ok(1); + } + }, + + CURVE25519_RISTRETTO => match group_op { + ADD => { + let cost = invoke_context + .get_compute_budget() + .curve25519_ristretto_add_cost; + question_mark!(invoke_context.get_compute_meter().consume(cost), result); + + let left_point = question_mark!( + translate_type::( + memory_mapping, + left_input_addr, + invoke_context.get_check_aligned(), + ), + result + ); + let right_point = question_mark!( + translate_type::( + memory_mapping, + right_input_addr, + invoke_context.get_check_aligned(), + ), + result + ); + + if let Some(result_point) = ristretto::add_ristretto(left_point, right_point) { + *question_mark!( + translate_type_mut::( + memory_mapping, + result_point_addr, + invoke_context.get_check_aligned(), + ), + result + ) = result_point; + *result = Ok(0); + } + } + SUB => { + let cost = invoke_context + .get_compute_budget() + .curve25519_ristretto_subtract_cost; + question_mark!(invoke_context.get_compute_meter().consume(cost), result); + + let left_point = question_mark!( + translate_type::( + memory_mapping, + left_input_addr, + invoke_context.get_check_aligned(), + ), + result + ); + let right_point = question_mark!( + translate_type::( + memory_mapping, + right_input_addr, + invoke_context.get_check_aligned(), + ), + result + ); + + if let Some(result_point) = + ristretto::subtract_ristretto(left_point, right_point) + { + *question_mark!( + translate_type_mut::( + memory_mapping, + result_point_addr, + invoke_context.get_check_aligned(), + ), + result + ) = result_point; + *result = Ok(0); + } + } + MUL => { + let cost = invoke_context + .get_compute_budget() + .curve25519_ristretto_multiply_cost; + question_mark!(invoke_context.get_compute_meter().consume(cost), result); + + let scalar = question_mark!( + translate_type::( + memory_mapping, + left_input_addr, + invoke_context.get_check_aligned(), + ), + result + ); + let input_point = question_mark!( + translate_type::( + memory_mapping, + right_input_addr, + invoke_context.get_check_aligned(), + ), + result + ); + + if let Some(result_point) = ristretto::multiply_ristretto(scalar, input_point) { + *question_mark!( + translate_type_mut::( + memory_mapping, + result_point_addr, + invoke_context.get_check_aligned(), + ), + result + ) = result_point; + *result = Ok(0); + } + } + _ => { + *result = Ok(1); + } + }, + + _ => { + *result = Ok(1); + } + } + } +); + declare_syscall!( // Blake3 SyscallBlake3, diff --git a/sdk/src/feature_set.rs b/sdk/src/feature_set.rs index 3f20b8105e..fd6bd35efa 100644 --- a/sdk/src/feature_set.rs +++ b/sdk/src/feature_set.rs @@ -149,9 +149,8 @@ pub mod zk_token_sdk_enabled { solana_sdk::declare_id!("zk1snxsc6Fh3wsGNbbHAJNHiJoYgF29mMnTSusGx5EJ"); } -// TODO: temporary address for now pub mod curve25519_syscall_enabled { - solana_sdk::declare_id!("curve25519111111111111111111111111111111111"); + solana_sdk::declare_id!("7rcw5UtqgDTBBv2EcynNfYckgdAaH1MAsCjKgXMkN7Ri"); } pub mod versioned_tx_message_enabled { diff --git a/zk-token-sdk/src/curve25519/curve_syscall_traits.rs b/zk-token-sdk/src/curve25519/curve_syscall_traits.rs index 05d1eb229a..3fc8fb177d 100644 --- a/zk-token-sdk/src/curve25519/curve_syscall_traits.rs +++ b/zk-token-sdk/src/curve25519/curve_syscall_traits.rs @@ -86,9 +86,20 @@ pub const MUL: u64 = 2; extern "C" { pub fn sol_curve_validate_point(curve_id: u64, point: *const u8, result: *mut u8) -> u64; - pub fn sol_curve_op(curve_id: u64, op_id: u64, point: *const u8, result: *mut u8) -> u64; + pub fn sol_curve_op( + curve_id: u64, + op_id: u64, + left_point: *const u8, + right_point: *const u8, + result: *mut u8, + ) -> u64; - pub fn sol_curve_multiscalar_mul(curve_id: u64, point: *const u8, result: *mut u8) -> u64; + pub fn sol_curve_multiscalar_mul( + curve_id: u64, + scalars: *const u8, + points: *const u8, + result: *mut u8, + ) -> u64; pub fn sol_curve_pairing_map(curve_id: u64, point: *const u8, result: *mut u8) -> u64; } diff --git a/zk-token-sdk/src/curve25519/edwards.rs b/zk-token-sdk/src/curve25519/edwards.rs index a37761ae17..de3a9b6ffb 100644 --- a/zk-token-sdk/src/curve25519/edwards.rs +++ b/zk-token-sdk/src/curve25519/edwards.rs @@ -129,7 +129,9 @@ mod target_arch { mod target_arch { use { super::*, - crate::curve25519::curve_syscall_traits::{sol_curve_validate_point, CURVE25519_EDWARDS}, + crate::curve25519::curve_syscall_traits::{ + sol_curve_op, sol_curve_validate_point, ADD, CURVE25519_EDWARDS, MUL, SUB, + }, }; pub fn validate_edwards(point: &PodEdwardsPoint) -> bool { @@ -141,9 +143,74 @@ mod target_arch { &mut validate_result, ) }; - result == 0 } + + pub fn add_edwards( + left_point: &PodEdwardsPoint, + right_point: &PodEdwardsPoint, + ) -> Option { + let mut result_point = PodEdwardsPoint::zeroed(); + let result = unsafe { + sol_curve_op( + CURVE25519_EDWARDS, + ADD, + &left_point.0 as *const u8, + &right_point.0 as *const u8, + &mut result_point.0 as *mut u8, + ) + }; + + if result == 0 { + Some(result_point) + } else { + None + } + } + + pub fn subtract_edwards( + left_point: &PodEdwardsPoint, + right_point: &PodEdwardsPoint, + ) -> Option { + let mut result_point = PodEdwardsPoint::zeroed(); + let result = unsafe { + sol_curve_op( + CURVE25519_EDWARDS, + SUB, + &left_point.0 as *const u8, + &right_point.0 as *const u8, + &mut result_point.0 as *mut u8, + ) + }; + + if result == 0 { + Some(result_point) + } else { + None + } + } + + pub fn multiply_edwards( + left_point: &PodEdwardsPoint, + right_point: &PodEdwardsPoint, + ) -> Option { + let mut result_point = PodEdwardsPoint::zeroed(); + let result = unsafe { + sol_curve_op( + CURVE25519_EDWARDS, + MUL, + &left_point.0 as *const u8, + &right_point.0 as *const u8, + &mut result_point.0 as *mut u8, + ) + }; + + if result == 0 { + Some(result_point) + } else { + None + } + } } #[cfg(test)] diff --git a/zk-token-sdk/src/curve25519/ristretto.rs b/zk-token-sdk/src/curve25519/ristretto.rs index 5542a756b7..03780a454a 100644 --- a/zk-token-sdk/src/curve25519/ristretto.rs +++ b/zk-token-sdk/src/curve25519/ristretto.rs @@ -130,7 +130,9 @@ mod target_arch { mod target_arch { use { super::*, - crate::curve25519::curve_syscall_traits::{sol_curve_validate_point, CURVE25519_RISTRETTO}, + crate::curve25519::curve_syscall_traits::{ + sol_curve_op, sol_curve_validate_point, ADD, CURVE25519_RISTRETTO, MUL, SUB, + }, }; pub fn validate_ristretto(point: &PodRistrettoPoint) -> bool { @@ -145,6 +147,72 @@ mod target_arch { result == 0 } + + pub fn add_ristretto( + left_point: &PodRistrettoPoint, + right_point: &PodRistrettoPoint, + ) -> Option { + let mut result_point = PodRistrettoPoint::zeroed(); + let result = unsafe { + sol_curve_op( + CURVE25519_RISTRETTO, + ADD, + &left_point.0 as *const u8, + &right_point.0 as *const u8, + &mut result_point.0 as *mut u8, + ) + }; + + if result == 0 { + Some(result_point) + } else { + None + } + } + + pub fn subtract_ristretto( + left_point: &PodRistrettoPoint, + right_point: &PodRistrettoPoint, + ) -> Option { + let mut result_point = PodRistrettoPoint::zeroed(); + let result = unsafe { + sol_curve_op( + CURVE25519_RISTRETTO, + SUB, + &left_point.0 as *const u8, + &right_point.0 as *const u8, + &mut result_point.0 as *mut u8, + ) + }; + + if result == 0 { + Some(result_point) + } else { + None + } + } + + pub fn multiply_ristretto( + left_point: &PodRistrettoPoint, + right_point: &PodRistrettoPoint, + ) -> Option { + let mut result_point = PodRistrettoPoint::zeroed(); + let result = unsafe { + sol_curve_op( + CURVE25519_RISTRETTO, + MUL, + &left_point.0 as *const u8, + &right_point.0 as *const u8, + &mut result_point.0 as *mut u8, + ) + }; + + if result == 0 { + Some(result_point) + } else { + None + } + } } #[cfg(test)]