Curve25519 syscall group ops (#25071)

* zk-token-sdk: implement group ops trait for curve25519

* zk-token-sdk: extend syscall trait implementation for group ops for ristretto

* zk-token-sdk: register curve25519 group ops to bpf loader

* zk-token-sdk: update curve25519_syscall_enabled address
This commit is contained in:
samkim-crypto 2022-05-08 11:28:07 +09:00 committed by GitHub
parent e6c02f30dd
commit aba6a89517
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 438 additions and 10 deletions

View File

@ -57,10 +57,22 @@ pub struct ComputeBudget {
pub syscall_base_cost: u64,
/// Number of compute units consumed to call zktoken_crypto_op
pub zk_token_elgamal_op_cost: u64, // to be replaced by curve25519 operations
/// Number of compute units consumed to add/sub two edwards points
/// Number of compute units consumed to validate a curve25519 edwards point
pub curve25519_edwards_validate_point_cost: u64,
/// Number of compute units consumed to add/sub two ristretto points
/// Number of compute units consumed to add two curve25519 edwards points
pub curve25519_edwards_add_cost: u64,
/// Number of compute units consumed to subtract two curve25519 edwards points
pub curve25519_edwards_subtract_cost: u64,
/// Number of compute units consumed to multiply a curve25519 edwards point
pub curve25519_edwards_multiply_cost: u64,
/// Number of compute units consumed to validate a curve25519 ristretto point
pub curve25519_ristretto_validate_point_cost: u64,
/// Number of compute units consumed to add two curve25519 ristretto points
pub curve25519_ristretto_add_cost: u64,
/// Number of compute units consumed to subtract two curve25519 ristretto points
pub curve25519_ristretto_subtract_cost: u64,
/// Number of compute units consumed to multiply a curve25519 ristretto point
pub curve25519_ristretto_multiply_cost: u64,
/// Optional program heap region size, if `None` then loader default
pub heap_size: Option<usize>,
/// Number of compute units per additional 32k heap above the default (~.5
@ -96,8 +108,14 @@ impl ComputeBudget {
secp256k1_recover_cost: 25_000,
syscall_base_cost: 100,
zk_token_elgamal_op_cost: 25_000,
curve25519_edwards_validate_point_cost: 25_000, // TODO: precisely determine cost
curve25519_edwards_validate_point_cost: 25_000, // TODO: precisely determine curve25519 costs
curve25519_edwards_add_cost: 25_000,
curve25519_edwards_subtract_cost: 25_000,
curve25519_edwards_multiply_cost: 25_000,
curve25519_ristretto_validate_point_cost: 25_000,
curve25519_ristretto_add_cost: 25_000,
curve25519_ristretto_subtract_cost: 25_000,
curve25519_ristretto_multiply_cost: 25_000,
heap_size: None,
heap_cost: 8,
mem_op_base_cost: 10,

View File

@ -261,6 +261,13 @@ pub fn register_syscalls(
SyscallCurvePointValidation::init,
SyscallCurvePointValidation::call,
)?;
register_feature_gated_syscall!(
syscall_registry,
curve25519_syscall_enabled,
b"sol_curve25519_point_validation",
SyscallCurveGroupOps::init,
SyscallCurveGroupOps::call,
)?;
// Sysvars
syscall_registry.register_syscall_by_name(
@ -1979,6 +1986,264 @@ declare_syscall!(
}
);
declare_syscall!(
// Elliptic Curve Group Operations
//
// Currently, only curve25519 Edwards and Ristretto representations are supported
SyscallCurveGroupOps,
fn call(
&mut self,
curve_id: u64,
group_op: u64,
left_input_addr: u64,
right_input_addr: u64,
result_point_addr: u64,
memory_mapping: &MemoryMapping,
result: &mut Result<u64, EbpfError<BpfError>>,
) {
use solana_zk_token_sdk::curve25519::{
curve_syscall_traits::*, edwards, ristretto, scalar,
};
let invoke_context = question_mark!(
self.invoke_context
.try_borrow()
.map_err(|_| SyscallError::InvokeContextBorrowFailed),
result
);
match curve_id {
CURVE25519_EDWARDS => match group_op {
ADD => {
let cost = invoke_context
.get_compute_budget()
.curve25519_edwards_add_cost;
question_mark!(invoke_context.get_compute_meter().consume(cost), result);
let left_point = question_mark!(
translate_type::<edwards::PodEdwardsPoint>(
memory_mapping,
left_input_addr,
invoke_context.get_check_aligned(),
),
result
);
let right_point = question_mark!(
translate_type::<edwards::PodEdwardsPoint>(
memory_mapping,
right_input_addr,
invoke_context.get_check_aligned(),
),
result
);
if let Some(result_point) = edwards::add_edwards(left_point, right_point) {
*question_mark!(
translate_type_mut::<edwards::PodEdwardsPoint>(
memory_mapping,
result_point_addr,
invoke_context.get_check_aligned(),
),
result
) = result_point;
*result = Ok(0);
}
}
SUB => {
let cost = invoke_context
.get_compute_budget()
.curve25519_edwards_subtract_cost;
question_mark!(invoke_context.get_compute_meter().consume(cost), result);
let left_point = question_mark!(
translate_type::<edwards::PodEdwardsPoint>(
memory_mapping,
left_input_addr,
invoke_context.get_check_aligned(),
),
result
);
let right_point = question_mark!(
translate_type::<edwards::PodEdwardsPoint>(
memory_mapping,
right_input_addr,
invoke_context.get_check_aligned(),
),
result
);
if let Some(result_point) = edwards::subtract_edwards(left_point, right_point) {
*question_mark!(
translate_type_mut::<edwards::PodEdwardsPoint>(
memory_mapping,
result_point_addr,
invoke_context.get_check_aligned(),
),
result
) = result_point;
*result = Ok(0);
}
}
MUL => {
let cost = invoke_context
.get_compute_budget()
.curve25519_edwards_multiply_cost;
question_mark!(invoke_context.get_compute_meter().consume(cost), result);
let scalar = question_mark!(
translate_type::<scalar::PodScalar>(
memory_mapping,
left_input_addr,
invoke_context.get_check_aligned(),
),
result
);
let input_point = question_mark!(
translate_type::<edwards::PodEdwardsPoint>(
memory_mapping,
right_input_addr,
invoke_context.get_check_aligned(),
),
result
);
if let Some(result_point) = edwards::multiply_edwards(scalar, input_point) {
*question_mark!(
translate_type_mut::<edwards::PodEdwardsPoint>(
memory_mapping,
result_point_addr,
invoke_context.get_check_aligned(),
),
result
) = result_point;
*result = Ok(0);
}
}
_ => {
*result = Ok(1);
}
},
CURVE25519_RISTRETTO => match group_op {
ADD => {
let cost = invoke_context
.get_compute_budget()
.curve25519_ristretto_add_cost;
question_mark!(invoke_context.get_compute_meter().consume(cost), result);
let left_point = question_mark!(
translate_type::<ristretto::PodRistrettoPoint>(
memory_mapping,
left_input_addr,
invoke_context.get_check_aligned(),
),
result
);
let right_point = question_mark!(
translate_type::<ristretto::PodRistrettoPoint>(
memory_mapping,
right_input_addr,
invoke_context.get_check_aligned(),
),
result
);
if let Some(result_point) = ristretto::add_ristretto(left_point, right_point) {
*question_mark!(
translate_type_mut::<ristretto::PodRistrettoPoint>(
memory_mapping,
result_point_addr,
invoke_context.get_check_aligned(),
),
result
) = result_point;
*result = Ok(0);
}
}
SUB => {
let cost = invoke_context
.get_compute_budget()
.curve25519_ristretto_subtract_cost;
question_mark!(invoke_context.get_compute_meter().consume(cost), result);
let left_point = question_mark!(
translate_type::<ristretto::PodRistrettoPoint>(
memory_mapping,
left_input_addr,
invoke_context.get_check_aligned(),
),
result
);
let right_point = question_mark!(
translate_type::<ristretto::PodRistrettoPoint>(
memory_mapping,
right_input_addr,
invoke_context.get_check_aligned(),
),
result
);
if let Some(result_point) =
ristretto::subtract_ristretto(left_point, right_point)
{
*question_mark!(
translate_type_mut::<ristretto::PodRistrettoPoint>(
memory_mapping,
result_point_addr,
invoke_context.get_check_aligned(),
),
result
) = result_point;
*result = Ok(0);
}
}
MUL => {
let cost = invoke_context
.get_compute_budget()
.curve25519_ristretto_multiply_cost;
question_mark!(invoke_context.get_compute_meter().consume(cost), result);
let scalar = question_mark!(
translate_type::<scalar::PodScalar>(
memory_mapping,
left_input_addr,
invoke_context.get_check_aligned(),
),
result
);
let input_point = question_mark!(
translate_type::<ristretto::PodRistrettoPoint>(
memory_mapping,
right_input_addr,
invoke_context.get_check_aligned(),
),
result
);
if let Some(result_point) = ristretto::multiply_ristretto(scalar, input_point) {
*question_mark!(
translate_type_mut::<ristretto::PodRistrettoPoint>(
memory_mapping,
result_point_addr,
invoke_context.get_check_aligned(),
),
result
) = result_point;
*result = Ok(0);
}
}
_ => {
*result = Ok(1);
}
},
_ => {
*result = Ok(1);
}
}
}
);
declare_syscall!(
// Blake3
SyscallBlake3,

View File

@ -149,9 +149,8 @@ pub mod zk_token_sdk_enabled {
solana_sdk::declare_id!("zk1snxsc6Fh3wsGNbbHAJNHiJoYgF29mMnTSusGx5EJ");
}
// TODO: temporary address for now
pub mod curve25519_syscall_enabled {
solana_sdk::declare_id!("curve25519111111111111111111111111111111111");
solana_sdk::declare_id!("7rcw5UtqgDTBBv2EcynNfYckgdAaH1MAsCjKgXMkN7Ri");
}
pub mod versioned_tx_message_enabled {

View File

@ -86,9 +86,20 @@ pub const MUL: u64 = 2;
extern "C" {
pub fn sol_curve_validate_point(curve_id: u64, point: *const u8, result: *mut u8) -> u64;
pub fn sol_curve_op(curve_id: u64, op_id: u64, point: *const u8, result: *mut u8) -> u64;
pub fn sol_curve_op(
curve_id: u64,
op_id: u64,
left_point: *const u8,
right_point: *const u8,
result: *mut u8,
) -> u64;
pub fn sol_curve_multiscalar_mul(curve_id: u64, point: *const u8, result: *mut u8) -> u64;
pub fn sol_curve_multiscalar_mul(
curve_id: u64,
scalars: *const u8,
points: *const u8,
result: *mut u8,
) -> u64;
pub fn sol_curve_pairing_map(curve_id: u64, point: *const u8, result: *mut u8) -> u64;
}

View File

@ -129,7 +129,9 @@ mod target_arch {
mod target_arch {
use {
super::*,
crate::curve25519::curve_syscall_traits::{sol_curve_validate_point, CURVE25519_EDWARDS},
crate::curve25519::curve_syscall_traits::{
sol_curve_op, sol_curve_validate_point, ADD, CURVE25519_EDWARDS, MUL, SUB,
},
};
pub fn validate_edwards(point: &PodEdwardsPoint) -> bool {
@ -141,9 +143,74 @@ mod target_arch {
&mut validate_result,
)
};
result == 0
}
pub fn add_edwards(
left_point: &PodEdwardsPoint,
right_point: &PodEdwardsPoint,
) -> Option<PodEdwardsPoint> {
let mut result_point = PodEdwardsPoint::zeroed();
let result = unsafe {
sol_curve_op(
CURVE25519_EDWARDS,
ADD,
&left_point.0 as *const u8,
&right_point.0 as *const u8,
&mut result_point.0 as *mut u8,
)
};
if result == 0 {
Some(result_point)
} else {
None
}
}
pub fn subtract_edwards(
left_point: &PodEdwardsPoint,
right_point: &PodEdwardsPoint,
) -> Option<PodEdwardsPoint> {
let mut result_point = PodEdwardsPoint::zeroed();
let result = unsafe {
sol_curve_op(
CURVE25519_EDWARDS,
SUB,
&left_point.0 as *const u8,
&right_point.0 as *const u8,
&mut result_point.0 as *mut u8,
)
};
if result == 0 {
Some(result_point)
} else {
None
}
}
pub fn multiply_edwards(
left_point: &PodEdwardsPoint,
right_point: &PodEdwardsPoint,
) -> Option<PodEdwardsPoint> {
let mut result_point = PodEdwardsPoint::zeroed();
let result = unsafe {
sol_curve_op(
CURVE25519_EDWARDS,
MUL,
&left_point.0 as *const u8,
&right_point.0 as *const u8,
&mut result_point.0 as *mut u8,
)
};
if result == 0 {
Some(result_point)
} else {
None
}
}
}
#[cfg(test)]

View File

@ -130,7 +130,9 @@ mod target_arch {
mod target_arch {
use {
super::*,
crate::curve25519::curve_syscall_traits::{sol_curve_validate_point, CURVE25519_RISTRETTO},
crate::curve25519::curve_syscall_traits::{
sol_curve_op, sol_curve_validate_point, ADD, CURVE25519_RISTRETTO, MUL, SUB,
},
};
pub fn validate_ristretto(point: &PodRistrettoPoint) -> bool {
@ -145,6 +147,72 @@ mod target_arch {
result == 0
}
pub fn add_ristretto(
left_point: &PodRistrettoPoint,
right_point: &PodRistrettoPoint,
) -> Option<PodRistrettoPoint> {
let mut result_point = PodRistrettoPoint::zeroed();
let result = unsafe {
sol_curve_op(
CURVE25519_RISTRETTO,
ADD,
&left_point.0 as *const u8,
&right_point.0 as *const u8,
&mut result_point.0 as *mut u8,
)
};
if result == 0 {
Some(result_point)
} else {
None
}
}
pub fn subtract_ristretto(
left_point: &PodRistrettoPoint,
right_point: &PodRistrettoPoint,
) -> Option<PodRistrettoPoint> {
let mut result_point = PodRistrettoPoint::zeroed();
let result = unsafe {
sol_curve_op(
CURVE25519_RISTRETTO,
SUB,
&left_point.0 as *const u8,
&right_point.0 as *const u8,
&mut result_point.0 as *mut u8,
)
};
if result == 0 {
Some(result_point)
} else {
None
}
}
pub fn multiply_ristretto(
left_point: &PodRistrettoPoint,
right_point: &PodRistrettoPoint,
) -> Option<PodRistrettoPoint> {
let mut result_point = PodRistrettoPoint::zeroed();
let result = unsafe {
sol_curve_op(
CURVE25519_RISTRETTO,
MUL,
&left_point.0 as *const u8,
&right_point.0 as *const u8,
&mut result_point.0 as *mut u8,
)
};
if result == 0 {
Some(result_point)
} else {
None
}
}
}
#[cfg(test)]