Curve25519 point validation syscall (#23771)
* zk-token-sdk: add curve25519 basic ops * zk-token-sdk: add bpf operations for curve25519 ops * zk-token-sdk: rebase * zk-token-sdk: add tests for curve25519 opertions * zk-token-sdk: rustfmt * zk-token-sdk: organize syscalls by trait * zk-token-sdk: organize syscalls by trait * zk-token-sdk: cleaning up * zk-token-sdk: rename mods * zk-token-sdk: cargo fmt * zk-token-sdk: fix tests for edwards and ristretto * zk-token-sdk: add Syscall object for curve point validation * zk-token-sdk: docs for curve syscall traits * zk-token-sdk: fix errors from rebase * zk-token-sdk: update Vec to slice Co-authored-by: Trent Nelson <trent.a.b.nelson@gmail.com> * zk-token-sdk: use enum with num-derive for curve ids * zk-token-sdk: update vec to slice * zk-token-sdk: make curve25519 tests be deterministic * zk-token-sdk: rebase * token-2022: re-organizing curve point validation * token-2022: cargo fmt * zk-token-sdk: minor Co-authored-by: Trent Nelson <trent.a.b.nelson@gmail.com>
This commit is contained in:
parent
c785f1ffc5
commit
d9deab4d2c
|
@ -56,7 +56,11 @@ pub struct ComputeBudget {
|
|||
/// Number of compute units consumed to do a syscall without any work
|
||||
pub syscall_base_cost: u64,
|
||||
/// Number of compute units consumed to call zktoken_crypto_op
|
||||
pub zk_token_elgamal_op_cost: u64,
|
||||
pub zk_token_elgamal_op_cost: u64, // to be replaced by curve25519 operations
|
||||
/// Number of compute units consumed to add/sub two edwards points
|
||||
pub curve25519_edwards_validate_point_cost: u64,
|
||||
/// Number of compute units consumed to add/sub two ristretto points
|
||||
pub curve25519_ristretto_validate_point_cost: u64,
|
||||
/// Optional program heap region size, if `None` then loader default
|
||||
pub heap_size: Option<usize>,
|
||||
/// Number of compute units per additional 32k heap above the default (~.5
|
||||
|
@ -92,6 +96,8 @@ impl ComputeBudget {
|
|||
secp256k1_recover_cost: 25_000,
|
||||
syscall_base_cost: 100,
|
||||
zk_token_elgamal_op_cost: 25_000,
|
||||
curve25519_edwards_validate_point_cost: 25_000, // TODO: precisely determine cost
|
||||
curve25519_ristretto_validate_point_cost: 25_000,
|
||||
heap_size: None,
|
||||
heap_cost: 8,
|
||||
mem_op_base_cost: 10,
|
||||
|
|
|
@ -22,12 +22,13 @@ use {
|
|||
entrypoint::{BPF_ALIGN_OF_U128, MAX_PERMITTED_DATA_INCREASE, SUCCESS},
|
||||
feature_set::{
|
||||
add_get_processed_sibling_instruction_syscall, blake3_syscall_enabled,
|
||||
check_physical_overlapping, check_slice_translation_size, disable_fees_sysvar,
|
||||
do_support_realloc, executables_incur_cpi_data_cost, fixed_memcpy_nonoverlapping_check,
|
||||
libsecp256k1_0_5_upgrade_enabled, limit_secp256k1_recovery_id,
|
||||
prevent_calling_precompiles_as_programs, return_data_syscall_enabled,
|
||||
secp256k1_recover_syscall_enabled, sol_log_data_syscall_enabled,
|
||||
syscall_saturated_math, update_syscall_base_costs, zk_token_sdk_enabled,
|
||||
check_physical_overlapping, check_slice_translation_size, curve25519_syscall_enabled,
|
||||
disable_fees_sysvar, do_support_realloc, executables_incur_cpi_data_cost,
|
||||
fixed_memcpy_nonoverlapping_check, libsecp256k1_0_5_upgrade_enabled,
|
||||
limit_secp256k1_recovery_id, prevent_calling_precompiles_as_programs,
|
||||
return_data_syscall_enabled, secp256k1_recover_syscall_enabled,
|
||||
sol_log_data_syscall_enabled, syscall_saturated_math, update_syscall_base_costs,
|
||||
zk_token_sdk_enabled,
|
||||
},
|
||||
hash::{Hasher, HASH_BYTES},
|
||||
instruction::{
|
||||
|
@ -137,6 +138,9 @@ pub fn register_syscalls(
|
|||
let zk_token_sdk_enabled = invoke_context
|
||||
.feature_set
|
||||
.is_active(&zk_token_sdk_enabled::id());
|
||||
let curve25519_syscall_enabled = invoke_context
|
||||
.feature_set
|
||||
.is_active(&curve25519_syscall_enabled::id());
|
||||
let disable_fees_sysvar = invoke_context
|
||||
.feature_set
|
||||
.is_active(&disable_fees_sysvar::id());
|
||||
|
@ -247,6 +251,17 @@ pub fn register_syscalls(
|
|||
SyscallZkTokenElgamalOpWithScalar::call,
|
||||
)?;
|
||||
|
||||
// Elliptic Curve Point Validation
|
||||
//
|
||||
// TODO: add group operations and multiscalar multiplications
|
||||
register_feature_gated_syscall!(
|
||||
syscall_registry,
|
||||
curve25519_syscall_enabled,
|
||||
b"sol_curve25519_point_validation",
|
||||
SyscallCurvePointValidation::init,
|
||||
SyscallCurvePointValidation::call,
|
||||
)?;
|
||||
|
||||
// Sysvars
|
||||
syscall_registry.register_syscall_by_name(
|
||||
b"sol_get_clock_sysvar",
|
||||
|
@ -1890,6 +1905,80 @@ declare_syscall!(
|
|||
}
|
||||
);
|
||||
|
||||
declare_syscall!(
|
||||
// Elliptic Curve Point Validation
|
||||
//
|
||||
// Currently, only curve25519 Edwards and Ristretto representations are supported
|
||||
SyscallCurvePointValidation,
|
||||
fn call(
|
||||
&mut self,
|
||||
curve_id: u64,
|
||||
point_addr: u64,
|
||||
_arg3: u64,
|
||||
_arg4: u64,
|
||||
_arg5: u64,
|
||||
memory_mapping: &MemoryMapping,
|
||||
result: &mut Result<u64, EbpfError<BpfError>>,
|
||||
) {
|
||||
use solana_zk_token_sdk::curve25519::{curve_syscall_traits::*, edwards, ristretto};
|
||||
|
||||
let invoke_context = question_mark!(
|
||||
self.invoke_context
|
||||
.try_borrow()
|
||||
.map_err(|_| SyscallError::InvokeContextBorrowFailed),
|
||||
result
|
||||
);
|
||||
|
||||
match curve_id {
|
||||
CURVE25519_EDWARDS => {
|
||||
let cost = invoke_context
|
||||
.get_compute_budget()
|
||||
.curve25519_edwards_validate_point_cost;
|
||||
question_mark!(invoke_context.get_compute_meter().consume(cost), result);
|
||||
|
||||
let point = question_mark!(
|
||||
translate_type::<edwards::PodEdwardsPoint>(
|
||||
memory_mapping,
|
||||
point_addr,
|
||||
invoke_context.get_check_aligned()
|
||||
),
|
||||
result
|
||||
);
|
||||
|
||||
if edwards::validate_edwards(point) {
|
||||
*result = Ok(0);
|
||||
} else {
|
||||
*result = Ok(1);
|
||||
}
|
||||
}
|
||||
CURVE25519_RISTRETTO => {
|
||||
let cost = invoke_context
|
||||
.get_compute_budget()
|
||||
.curve25519_ristretto_validate_point_cost;
|
||||
question_mark!(invoke_context.get_compute_meter().consume(cost), result);
|
||||
|
||||
let point = question_mark!(
|
||||
translate_type::<ristretto::PodRistrettoPoint>(
|
||||
memory_mapping,
|
||||
point_addr,
|
||||
invoke_context.get_check_aligned()
|
||||
),
|
||||
result
|
||||
);
|
||||
|
||||
if ristretto::validate_ristretto(point) {
|
||||
*result = Ok(0);
|
||||
} else {
|
||||
*result = Ok(1);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
*result = Ok(1);
|
||||
}
|
||||
};
|
||||
}
|
||||
);
|
||||
|
||||
declare_syscall!(
|
||||
// Blake3
|
||||
SyscallBlake3,
|
||||
|
|
|
@ -149,6 +149,11 @@ pub mod zk_token_sdk_enabled {
|
|||
solana_sdk::declare_id!("zk1snxsc6Fh3wsGNbbHAJNHiJoYgF29mMnTSusGx5EJ");
|
||||
}
|
||||
|
||||
// TODO: temporary address for now
|
||||
pub mod curve25519_syscall_enabled {
|
||||
solana_sdk::declare_id!("curve25519111111111111111111111111111111111");
|
||||
}
|
||||
|
||||
pub mod versioned_tx_message_enabled {
|
||||
solana_sdk::declare_id!("3KZZ6Ks1885aGBQ45fwRcPXVBCtzUvxhUTkwKMR41Tca");
|
||||
}
|
||||
|
|
|
@ -0,0 +1,94 @@
|
|||
//! The traits representing the basic elliptic curve operations.
|
||||
//!
|
||||
//! These traits are instantiatable by all the commonly used elliptic curves and should help in
|
||||
//! organizing syscall support for other curves in the future. more complicated or curve-specific
|
||||
//! functions that are needed in cryptographic applications should be representable by combining
|
||||
//! the associated functions of these traits.
|
||||
//!
|
||||
//! NOTE: This module temporarily lives in zk_token_sdk/curve25519, but it is independent of
|
||||
//! zk-token-sdk or curve25519. It should be moved to a more general location in the future.
|
||||
//!
|
||||
|
||||
pub trait PointValidation {
|
||||
type Point;
|
||||
|
||||
/// Verifies if a byte representation of a curve point lies in the curve.
|
||||
fn validate_point(&self) -> bool;
|
||||
}
|
||||
|
||||
pub trait GroupOperations {
|
||||
type Point;
|
||||
type Scalar;
|
||||
|
||||
/// Adds two curve points: P_0 + P_1.
|
||||
fn add(left_point: &Self::Point, right_point: &Self::Point) -> Option<Self::Point>;
|
||||
|
||||
/// Subtracts two curve points: P_0 - P_1.
|
||||
///
|
||||
/// NOTE: Altneratively, one can consider replacing this with a `negate` function that maps a
|
||||
/// curve point P -> -P. Then subtraction can be computed by combining `negate` and `add`
|
||||
/// syscalls. However, `subtract` is a much more widely used function than `negate`.
|
||||
fn subtract(left_point: &Self::Point, right_point: &Self::Point) -> Option<Self::Point>;
|
||||
|
||||
/// Multiplies a scalar S with a curve point P: S*P
|
||||
fn multiply(scalar: &Self::Scalar, point: &Self::Point) -> Option<Self::Point>;
|
||||
}
|
||||
|
||||
pub trait MultiScalarMultiplication {
|
||||
type Scalar;
|
||||
type Point;
|
||||
|
||||
/// Given a vector of scalsrs S_1, ..., S_N, and curve points P_1, ..., P_N, computes the
|
||||
/// "inner product": S_1*P_1 + ... + S_N*P_N.
|
||||
///
|
||||
/// NOTE: This operation can be represented by combining `add` and `multiply` functions in
|
||||
/// `GroupOperations`, but computing it in a single batch is significantly cheaper. Given how
|
||||
/// commonly used the multiscalar multiplication (MSM) is, it seems to make sense to have a
|
||||
/// designated trait for MSM support.
|
||||
///
|
||||
/// NOTE: The inputs to the function is a non-fixed size vector and hence, there are some
|
||||
/// complications in computing the cost for the syscall. The computational costs should only
|
||||
/// depend on the length of the vectors (and the curve), so it would be ideal to support
|
||||
/// variable length inputs and compute the syscall cost as is done in eip-197:
|
||||
/// https://github.com/ethereum/EIPs/blob/master/EIPS/eip-197.md#gas-costs. If not, then we can
|
||||
/// consider bounding the length of the input and assigning worst-case cost.
|
||||
fn multiscalar_multiply(
|
||||
scalars: &[Self::Scalar],
|
||||
points: &[Self::Point],
|
||||
) -> Option<Self::Point>;
|
||||
}
|
||||
|
||||
pub trait Pairing {
|
||||
type G1Point;
|
||||
type G2Point;
|
||||
type GTPoint;
|
||||
|
||||
/// Applies the bilinear pairing operation to two curve points P1, P2 -> e(P1, P2). This trait
|
||||
/// is only relevant for "pairing-friendly" curves such as BN254 and BLS12-381.
|
||||
fn pairing_map(
|
||||
left_point: &Self::G1Point,
|
||||
right_point: &Self::G2Point,
|
||||
) -> Option<Self::GTPoint>;
|
||||
}
|
||||
|
||||
pub const CURVE25519_EDWARDS: u64 = 0;
|
||||
pub const CURVE25519_RISTRETTO: u64 = 1;
|
||||
|
||||
pub const ADD: u64 = 0;
|
||||
pub const SUB: u64 = 1;
|
||||
pub const MUL: u64 = 2;
|
||||
|
||||
// Functions are organized by the curve traits, which can be instantiated by multiple curve
|
||||
// representations. The functions take in a `curve_id` (e.g. `CURVE25519_EDWARDS`) and should run
|
||||
// the associated functions in the appropriate trait instantiation. The `curve_op` function
|
||||
// additionally takes in an `op_id` (e.g. `ADD`) that controls which associated functions to run in
|
||||
// `GroupOperations`.
|
||||
extern "C" {
|
||||
pub fn sol_curve_validate_point(curve_id: u64, point: *const u8, result: *mut u8) -> u64;
|
||||
|
||||
pub fn sol_curve_op(curve_id: u64, op_id: u64, point: *const u8, result: *mut u8) -> u64;
|
||||
|
||||
pub fn sol_curve_multiscalar_mul(curve_id: u64, point: *const u8, result: *mut u8) -> u64;
|
||||
|
||||
pub fn sol_curve_pairing_map(curve_id: u64, point: *const u8, result: *mut u8) -> u64;
|
||||
}
|
|
@ -0,0 +1,286 @@
|
|||
use bytemuck::{Pod, Zeroable};
|
||||
pub use target_arch::*;
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
|
||||
#[repr(transparent)]
|
||||
pub struct PodEdwardsPoint(pub [u8; 32]);
|
||||
|
||||
#[cfg(not(target_arch = "bpf"))]
|
||||
mod target_arch {
|
||||
use {
|
||||
super::*,
|
||||
crate::curve25519::{
|
||||
curve_syscall_traits::{GroupOperations, MultiScalarMultiplication, PointValidation},
|
||||
errors::Curve25519Error,
|
||||
scalar::PodScalar,
|
||||
},
|
||||
curve25519_dalek::{
|
||||
edwards::{CompressedEdwardsY, EdwardsPoint},
|
||||
scalar::Scalar,
|
||||
traits::VartimeMultiscalarMul,
|
||||
},
|
||||
};
|
||||
|
||||
pub fn validate_edwards(point: &PodEdwardsPoint) -> bool {
|
||||
point.validate_point()
|
||||
}
|
||||
|
||||
pub fn add_edwards(
|
||||
left_point: &PodEdwardsPoint,
|
||||
right_point: &PodEdwardsPoint,
|
||||
) -> Option<PodEdwardsPoint> {
|
||||
PodEdwardsPoint::add(left_point, right_point)
|
||||
}
|
||||
|
||||
pub fn subtract_edwards(
|
||||
left_point: &PodEdwardsPoint,
|
||||
right_point: &PodEdwardsPoint,
|
||||
) -> Option<PodEdwardsPoint> {
|
||||
PodEdwardsPoint::subtract(left_point, right_point)
|
||||
}
|
||||
|
||||
pub fn multiply_edwards(
|
||||
scalar: &PodScalar,
|
||||
point: &PodEdwardsPoint,
|
||||
) -> Option<PodEdwardsPoint> {
|
||||
PodEdwardsPoint::multiply(scalar, point)
|
||||
}
|
||||
|
||||
pub fn multiscalar_multiply_edwards(
|
||||
scalars: &[PodScalar],
|
||||
points: &[PodEdwardsPoint],
|
||||
) -> Option<PodEdwardsPoint> {
|
||||
PodEdwardsPoint::multiscalar_multiply(scalars, points)
|
||||
}
|
||||
|
||||
impl From<&EdwardsPoint> for PodEdwardsPoint {
|
||||
fn from(point: &EdwardsPoint) -> Self {
|
||||
Self(point.compress().to_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&PodEdwardsPoint> for EdwardsPoint {
|
||||
type Error = Curve25519Error;
|
||||
|
||||
fn try_from(pod: &PodEdwardsPoint) -> Result<Self, Self::Error> {
|
||||
CompressedEdwardsY::from_slice(&pod.0)
|
||||
.decompress()
|
||||
.ok_or(Curve25519Error::PodConversion)
|
||||
}
|
||||
}
|
||||
|
||||
impl PointValidation for PodEdwardsPoint {
|
||||
type Point = Self;
|
||||
|
||||
fn validate_point(&self) -> bool {
|
||||
CompressedEdwardsY::from_slice(&self.0)
|
||||
.decompress()
|
||||
.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
impl GroupOperations for PodEdwardsPoint {
|
||||
type Scalar = PodScalar;
|
||||
type Point = Self;
|
||||
|
||||
fn add(left_point: &Self, right_point: &Self) -> Option<Self> {
|
||||
let left_point: EdwardsPoint = left_point.try_into().ok()?;
|
||||
let right_point: EdwardsPoint = right_point.try_into().ok()?;
|
||||
|
||||
let result = &left_point + &right_point;
|
||||
Some((&result).into())
|
||||
}
|
||||
|
||||
fn subtract(left_point: &Self, right_point: &Self) -> Option<Self> {
|
||||
let left_point: EdwardsPoint = left_point.try_into().ok()?;
|
||||
let right_point: EdwardsPoint = right_point.try_into().ok()?;
|
||||
|
||||
let result = &left_point - &right_point;
|
||||
Some((&result).into())
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "bpf"))]
|
||||
fn multiply(scalar: &PodScalar, point: &Self) -> Option<Self> {
|
||||
let scalar: Scalar = scalar.into();
|
||||
let point: EdwardsPoint = point.try_into().ok()?;
|
||||
|
||||
let result = &scalar * &point;
|
||||
Some((&result).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl MultiScalarMultiplication for PodEdwardsPoint {
|
||||
type Scalar = PodScalar;
|
||||
type Point = Self;
|
||||
|
||||
fn multiscalar_multiply(scalars: &[PodScalar], points: &[Self]) -> Option<Self> {
|
||||
EdwardsPoint::optional_multiscalar_mul(
|
||||
scalars.iter().map(Scalar::from),
|
||||
points
|
||||
.iter()
|
||||
.map(|point| EdwardsPoint::try_from(point).ok()),
|
||||
)
|
||||
.map(|result| PodEdwardsPoint::from(&result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "bpf")]
|
||||
mod target_arch {
|
||||
use {
|
||||
super::*,
|
||||
crate::curve25519::curve_syscall_traits::{sol_curve_validate_point, CURVE25519_EDWARDS},
|
||||
};
|
||||
|
||||
pub fn validate_edwards(point: &PodEdwardsPoint) -> bool {
|
||||
let mut validate_result = 0u8;
|
||||
let result = unsafe {
|
||||
sol_curve_validate_point(
|
||||
CURVE25519_EDWARDS,
|
||||
&point.0 as *const u8,
|
||||
&mut validate_result,
|
||||
)
|
||||
};
|
||||
|
||||
result == 0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use {
|
||||
super::*,
|
||||
crate::curve25519::scalar::PodScalar,
|
||||
curve25519_dalek::{
|
||||
constants::ED25519_BASEPOINT_POINT as G, edwards::EdwardsPoint, traits::Identity,
|
||||
},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_validate_edwards() {
|
||||
let pod = PodEdwardsPoint(G.compress().to_bytes());
|
||||
assert!(validate_edwards(&pod));
|
||||
|
||||
let invalid_bytes = [
|
||||
120, 140, 152, 233, 41, 227, 203, 27, 87, 115, 25, 251, 219, 5, 84, 148, 117, 38, 84,
|
||||
60, 87, 144, 161, 146, 42, 34, 91, 155, 158, 189, 121, 79,
|
||||
];
|
||||
|
||||
assert!(!validate_edwards(&PodEdwardsPoint(invalid_bytes)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_edwards_add_subtract() {
|
||||
// identity
|
||||
let identity = PodEdwardsPoint(EdwardsPoint::identity().compress().to_bytes());
|
||||
let point = PodEdwardsPoint([
|
||||
201, 179, 241, 122, 180, 185, 239, 50, 183, 52, 221, 0, 153, 195, 43, 18, 22, 38, 187,
|
||||
206, 179, 192, 210, 58, 53, 45, 150, 98, 89, 17, 158, 11,
|
||||
]);
|
||||
|
||||
assert_eq!(add_edwards(&point, &identity).unwrap(), point);
|
||||
assert_eq!(subtract_edwards(&point, &identity).unwrap(), point);
|
||||
|
||||
// associativity
|
||||
let point_a = PodEdwardsPoint([
|
||||
33, 124, 71, 170, 117, 69, 151, 247, 59, 12, 95, 125, 133, 166, 64, 5, 2, 27, 90, 27,
|
||||
200, 167, 59, 164, 52, 54, 52, 200, 29, 13, 34, 213,
|
||||
]);
|
||||
let point_b = PodEdwardsPoint([
|
||||
70, 222, 137, 221, 253, 204, 71, 51, 78, 8, 124, 1, 67, 200, 102, 225, 122, 228, 111,
|
||||
183, 129, 14, 131, 210, 212, 95, 109, 246, 55, 10, 159, 91,
|
||||
]);
|
||||
let point_c = PodEdwardsPoint([
|
||||
72, 60, 66, 143, 59, 197, 111, 36, 181, 137, 25, 97, 157, 201, 247, 215, 123, 83, 220,
|
||||
250, 154, 150, 180, 192, 196, 28, 215, 137, 34, 247, 39, 129,
|
||||
]);
|
||||
|
||||
assert_eq!(
|
||||
add_edwards(&add_edwards(&point_a, &point_b).unwrap(), &point_c),
|
||||
add_edwards(&point_a, &add_edwards(&point_b, &point_c).unwrap()),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
subtract_edwards(&subtract_edwards(&point_a, &point_b).unwrap(), &point_c),
|
||||
subtract_edwards(&point_a, &add_edwards(&point_b, &point_c).unwrap()),
|
||||
);
|
||||
|
||||
// commutativity
|
||||
assert_eq!(
|
||||
add_edwards(&point_a, &point_b).unwrap(),
|
||||
add_edwards(&point_b, &point_a).unwrap(),
|
||||
);
|
||||
|
||||
// subtraction
|
||||
let point = PodEdwardsPoint(G.compress().to_bytes());
|
||||
let point_negated = PodEdwardsPoint((-G).compress().to_bytes());
|
||||
|
||||
assert_eq!(point_negated, subtract_edwards(&identity, &point).unwrap(),)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_edwards_mul() {
|
||||
let scalar_a = PodScalar([
|
||||
72, 191, 131, 55, 85, 86, 54, 60, 116, 10, 39, 130, 180, 3, 90, 227, 47, 228, 252, 99,
|
||||
151, 71, 118, 29, 34, 102, 117, 114, 120, 50, 57, 8,
|
||||
]);
|
||||
let point_x = PodEdwardsPoint([
|
||||
176, 121, 6, 191, 108, 161, 206, 141, 73, 14, 235, 97, 49, 68, 48, 112, 98, 215, 145,
|
||||
208, 44, 188, 70, 10, 180, 124, 230, 15, 98, 165, 104, 85,
|
||||
]);
|
||||
let point_y = PodEdwardsPoint([
|
||||
174, 86, 89, 208, 236, 123, 223, 128, 75, 54, 228, 232, 220, 100, 205, 108, 237, 97,
|
||||
105, 79, 74, 192, 67, 224, 185, 23, 157, 116, 216, 151, 223, 81,
|
||||
]);
|
||||
|
||||
let ax = multiply_edwards(&scalar_a, &point_x).unwrap();
|
||||
let bx = multiply_edwards(&scalar_a, &point_y).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
add_edwards(&ax, &bx),
|
||||
multiply_edwards(&scalar_a, &add_edwards(&point_x, &point_y).unwrap()),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiscalar_multiplication_edwards() {
|
||||
let scalar = PodScalar([
|
||||
205, 73, 127, 173, 83, 80, 190, 66, 202, 3, 237, 77, 52, 223, 238, 70, 80, 242, 24, 87,
|
||||
111, 84, 49, 63, 194, 76, 202, 108, 62, 240, 83, 15,
|
||||
]);
|
||||
let point = PodEdwardsPoint([
|
||||
222, 174, 184, 139, 143, 122, 253, 96, 0, 207, 120, 157, 112, 38, 54, 189, 91, 144, 78,
|
||||
111, 111, 122, 140, 183, 65, 250, 191, 133, 6, 42, 212, 93,
|
||||
]);
|
||||
|
||||
let basic_product = multiply_edwards(&scalar, &point).unwrap();
|
||||
let msm_product = multiscalar_multiply_edwards(&[scalar], &[point]).unwrap();
|
||||
|
||||
assert_eq!(basic_product, msm_product);
|
||||
|
||||
let scalar_a = PodScalar([
|
||||
246, 154, 34, 110, 31, 185, 50, 1, 252, 194, 163, 56, 211, 18, 101, 192, 57, 225, 207,
|
||||
69, 19, 84, 231, 118, 137, 175, 148, 218, 106, 212, 69, 9,
|
||||
]);
|
||||
let scalar_b = PodScalar([
|
||||
27, 58, 126, 136, 253, 178, 176, 245, 246, 55, 15, 202, 35, 183, 66, 199, 134, 187,
|
||||
169, 154, 66, 120, 169, 193, 75, 4, 33, 241, 126, 227, 59, 3,
|
||||
]);
|
||||
let point_x = PodEdwardsPoint([
|
||||
252, 31, 230, 46, 173, 95, 144, 148, 158, 157, 63, 10, 8, 68, 58, 176, 142, 192, 168,
|
||||
53, 61, 105, 194, 166, 43, 56, 246, 236, 28, 146, 114, 133,
|
||||
]);
|
||||
let point_y = PodEdwardsPoint([
|
||||
10, 111, 8, 236, 97, 189, 124, 69, 89, 176, 222, 39, 199, 253, 111, 11, 248, 186, 128,
|
||||
90, 120, 128, 248, 210, 232, 183, 93, 104, 111, 150, 7, 241,
|
||||
]);
|
||||
|
||||
let ax = multiply_edwards(&scalar_a, &point_x).unwrap();
|
||||
let by = multiply_edwards(&scalar_b, &point_y).unwrap();
|
||||
let basic_product = add_edwards(&ax, &by).unwrap();
|
||||
let msm_product =
|
||||
multiscalar_multiply_edwards(&[scalar_a, scalar_b], &[point_x, point_y]).unwrap();
|
||||
|
||||
assert_eq!(basic_product, msm_product);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Clone, Debug, Eq, PartialEq)]
|
||||
pub enum Curve25519Error {
|
||||
#[error("pod conversion failed")]
|
||||
PodConversion,
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
//! Syscall operations for curve25519
|
||||
//!
|
||||
//! This module lives inside the zk-token-sdk for now, but should move to a general location since
|
||||
//! it is independent of zk-tokens.
|
||||
|
||||
pub mod curve_syscall_traits;
|
||||
pub mod edwards;
|
||||
#[cfg(not(target_arch = "bpf"))]
|
||||
pub mod errors;
|
||||
pub mod ristretto;
|
||||
pub mod scalar;
|
|
@ -0,0 +1,290 @@
|
|||
use bytemuck::{Pod, Zeroable};
|
||||
pub use target_arch::*;
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
|
||||
#[repr(transparent)]
|
||||
pub struct PodRistrettoPoint(pub [u8; 32]);
|
||||
|
||||
#[cfg(not(target_arch = "bpf"))]
|
||||
mod target_arch {
|
||||
use {
|
||||
super::*,
|
||||
crate::curve25519::{
|
||||
curve_syscall_traits::{GroupOperations, MultiScalarMultiplication, PointValidation},
|
||||
errors::Curve25519Error,
|
||||
scalar::PodScalar,
|
||||
},
|
||||
curve25519_dalek::{
|
||||
ristretto::{CompressedRistretto, RistrettoPoint},
|
||||
scalar::Scalar,
|
||||
traits::VartimeMultiscalarMul,
|
||||
},
|
||||
};
|
||||
|
||||
pub fn validate_ristretto(point: &PodRistrettoPoint) -> bool {
|
||||
point.validate_point()
|
||||
}
|
||||
|
||||
pub fn add_ristretto(
|
||||
left_point: &PodRistrettoPoint,
|
||||
right_point: &PodRistrettoPoint,
|
||||
) -> Option<PodRistrettoPoint> {
|
||||
PodRistrettoPoint::add(left_point, right_point)
|
||||
}
|
||||
|
||||
pub fn subtract_ristretto(
|
||||
left_point: &PodRistrettoPoint,
|
||||
right_point: &PodRistrettoPoint,
|
||||
) -> Option<PodRistrettoPoint> {
|
||||
PodRistrettoPoint::subtract(left_point, right_point)
|
||||
}
|
||||
|
||||
pub fn multiply_ristretto(
|
||||
scalar: &PodScalar,
|
||||
point: &PodRistrettoPoint,
|
||||
) -> Option<PodRistrettoPoint> {
|
||||
PodRistrettoPoint::multiply(scalar, point)
|
||||
}
|
||||
|
||||
pub fn multiscalar_multiply_ristretto(
|
||||
scalars: &[PodScalar],
|
||||
points: &[PodRistrettoPoint],
|
||||
) -> Option<PodRistrettoPoint> {
|
||||
PodRistrettoPoint::multiscalar_multiply(scalars, points)
|
||||
}
|
||||
|
||||
impl From<&RistrettoPoint> for PodRistrettoPoint {
|
||||
fn from(point: &RistrettoPoint) -> Self {
|
||||
Self(point.compress().to_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&PodRistrettoPoint> for RistrettoPoint {
|
||||
type Error = Curve25519Error;
|
||||
|
||||
fn try_from(pod: &PodRistrettoPoint) -> Result<Self, Self::Error> {
|
||||
CompressedRistretto::from_slice(&pod.0)
|
||||
.decompress()
|
||||
.ok_or(Curve25519Error::PodConversion)
|
||||
}
|
||||
}
|
||||
|
||||
impl PointValidation for PodRistrettoPoint {
|
||||
type Point = Self;
|
||||
|
||||
fn validate_point(&self) -> bool {
|
||||
CompressedRistretto::from_slice(&self.0)
|
||||
.decompress()
|
||||
.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
impl GroupOperations for PodRistrettoPoint {
|
||||
type Scalar = PodScalar;
|
||||
type Point = Self;
|
||||
|
||||
fn add(left_point: &Self, right_point: &Self) -> Option<Self> {
|
||||
let left_point: RistrettoPoint = left_point.try_into().ok()?;
|
||||
let right_point: RistrettoPoint = right_point.try_into().ok()?;
|
||||
|
||||
let result = &left_point + &right_point;
|
||||
Some((&result).into())
|
||||
}
|
||||
|
||||
fn subtract(left_point: &Self, right_point: &Self) -> Option<Self> {
|
||||
let left_point: RistrettoPoint = left_point.try_into().ok()?;
|
||||
let right_point: RistrettoPoint = right_point.try_into().ok()?;
|
||||
|
||||
let result = &left_point - &right_point;
|
||||
Some((&result).into())
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "bpf"))]
|
||||
fn multiply(scalar: &PodScalar, point: &Self) -> Option<Self> {
|
||||
let scalar: Scalar = scalar.into();
|
||||
let point: RistrettoPoint = point.try_into().ok()?;
|
||||
|
||||
let result = &scalar * &point;
|
||||
Some((&result).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl MultiScalarMultiplication for PodRistrettoPoint {
|
||||
type Scalar = PodScalar;
|
||||
type Point = Self;
|
||||
|
||||
fn multiscalar_multiply(scalars: &[PodScalar], points: &[Self]) -> Option<Self> {
|
||||
RistrettoPoint::optional_multiscalar_mul(
|
||||
scalars.iter().map(Scalar::from),
|
||||
points
|
||||
.iter()
|
||||
.map(|point| RistrettoPoint::try_from(point).ok()),
|
||||
)
|
||||
.map(|result| PodRistrettoPoint::from(&result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "bpf")]
|
||||
#[allow(unused_variables)]
|
||||
mod target_arch {
|
||||
use {
|
||||
super::*,
|
||||
crate::curve25519::curve_syscall_traits::{sol_curve_validate_point, CURVE25519_RISTRETTO},
|
||||
};
|
||||
|
||||
pub fn validate_ristretto(point: &PodRistrettoPoint) -> bool {
|
||||
let mut validate_result = 0u8;
|
||||
let result = unsafe {
|
||||
sol_curve_validate_point(
|
||||
CURVE25519_RISTRETTO,
|
||||
&point.0 as *const u8,
|
||||
&mut validate_result,
|
||||
)
|
||||
};
|
||||
|
||||
result == 0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use {
|
||||
super::*,
|
||||
crate::curve25519::scalar::PodScalar,
|
||||
curve25519_dalek::{
|
||||
constants::RISTRETTO_BASEPOINT_POINT as G, ristretto::RistrettoPoint, traits::Identity,
|
||||
},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_validate_ristretto() {
|
||||
let pod = PodRistrettoPoint(G.compress().to_bytes());
|
||||
assert!(validate_ristretto(&pod));
|
||||
|
||||
let invalid_bytes = [
|
||||
120, 140, 152, 233, 41, 227, 203, 27, 87, 115, 25, 251, 219, 5, 84, 148, 117, 38, 84,
|
||||
60, 87, 144, 161, 146, 42, 34, 91, 155, 158, 189, 121, 79,
|
||||
];
|
||||
|
||||
assert!(!validate_ristretto(&PodRistrettoPoint(invalid_bytes)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_subtract_ristretto() {
|
||||
// identity
|
||||
let identity = PodRistrettoPoint(RistrettoPoint::identity().compress().to_bytes());
|
||||
let point = PodRistrettoPoint([
|
||||
210, 174, 124, 127, 67, 77, 11, 114, 71, 63, 168, 136, 113, 20, 141, 228, 195, 254,
|
||||
232, 229, 220, 249, 213, 232, 61, 238, 152, 249, 83, 225, 206, 16,
|
||||
]);
|
||||
|
||||
assert_eq!(add_ristretto(&point, &identity).unwrap(), point);
|
||||
assert_eq!(subtract_ristretto(&point, &identity).unwrap(), point);
|
||||
|
||||
// associativity
|
||||
let point_a = PodRistrettoPoint([
|
||||
208, 165, 125, 204, 2, 100, 218, 17, 170, 194, 23, 9, 102, 156, 134, 136, 217, 190, 98,
|
||||
34, 183, 194, 228, 153, 92, 11, 108, 103, 28, 57, 88, 15,
|
||||
]);
|
||||
let point_b = PodRistrettoPoint([
|
||||
208, 241, 72, 163, 73, 53, 32, 174, 54, 194, 71, 8, 70, 181, 244, 199, 93, 147, 99,
|
||||
231, 162, 127, 25, 40, 39, 19, 140, 132, 112, 212, 145, 108,
|
||||
]);
|
||||
let point_c = PodRistrettoPoint([
|
||||
250, 61, 200, 25, 195, 15, 144, 179, 24, 17, 252, 167, 247, 44, 47, 41, 104, 237, 49,
|
||||
137, 231, 173, 86, 106, 121, 249, 245, 247, 70, 188, 31, 49,
|
||||
]);
|
||||
|
||||
assert_eq!(
|
||||
add_ristretto(&add_ristretto(&point_a, &point_b).unwrap(), &point_c),
|
||||
add_ristretto(&point_a, &add_ristretto(&point_b, &point_c).unwrap()),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
subtract_ristretto(&subtract_ristretto(&point_a, &point_b).unwrap(), &point_c),
|
||||
subtract_ristretto(&point_a, &add_ristretto(&point_b, &point_c).unwrap()),
|
||||
);
|
||||
|
||||
// commutativity
|
||||
assert_eq!(
|
||||
add_ristretto(&point_a, &point_b).unwrap(),
|
||||
add_ristretto(&point_b, &point_a).unwrap(),
|
||||
);
|
||||
|
||||
// subtraction
|
||||
let point = PodRistrettoPoint(G.compress().to_bytes());
|
||||
let point_negated = PodRistrettoPoint((-G).compress().to_bytes());
|
||||
|
||||
assert_eq!(
|
||||
point_negated,
|
||||
subtract_ristretto(&identity, &point).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiply_ristretto() {
|
||||
let scalar_x = PodScalar([
|
||||
254, 198, 23, 138, 67, 243, 184, 110, 236, 115, 236, 205, 205, 215, 79, 114, 45, 250,
|
||||
78, 137, 3, 107, 136, 237, 49, 126, 117, 223, 37, 191, 88, 6,
|
||||
]);
|
||||
let point_a = PodRistrettoPoint([
|
||||
68, 80, 232, 181, 241, 77, 60, 81, 154, 51, 173, 35, 98, 234, 149, 37, 1, 39, 191, 201,
|
||||
193, 48, 88, 189, 97, 126, 63, 35, 144, 145, 203, 31,
|
||||
]);
|
||||
let point_b = PodRistrettoPoint([
|
||||
200, 236, 1, 12, 244, 130, 226, 214, 28, 125, 43, 163, 222, 234, 81, 213, 201, 156, 31,
|
||||
4, 167, 132, 240, 76, 164, 18, 45, 20, 48, 85, 206, 121,
|
||||
]);
|
||||
|
||||
let ax = multiply_ristretto(&scalar_x, &point_a).unwrap();
|
||||
let bx = multiply_ristretto(&scalar_x, &point_b).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
add_ristretto(&ax, &bx),
|
||||
multiply_ristretto(&scalar_x, &add_ristretto(&point_a, &point_b).unwrap()),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiscalar_multiplication_ristretto() {
|
||||
let scalar = PodScalar([
|
||||
123, 108, 109, 66, 154, 185, 88, 122, 178, 43, 17, 154, 201, 223, 31, 238, 59, 215, 71,
|
||||
154, 215, 143, 177, 158, 9, 136, 32, 223, 139, 13, 133, 5,
|
||||
]);
|
||||
let point = PodRistrettoPoint([
|
||||
158, 2, 130, 90, 148, 36, 172, 155, 86, 196, 74, 139, 30, 98, 44, 225, 155, 207, 135,
|
||||
111, 238, 167, 235, 67, 234, 125, 0, 227, 146, 31, 24, 113,
|
||||
]);
|
||||
|
||||
let basic_product = multiply_ristretto(&scalar, &point).unwrap();
|
||||
let msm_product = multiscalar_multiply_ristretto(&[scalar], &[point]).unwrap();
|
||||
|
||||
assert_eq!(basic_product, msm_product);
|
||||
|
||||
let scalar_a = PodScalar([
|
||||
8, 161, 219, 155, 192, 137, 153, 26, 27, 40, 30, 17, 124, 194, 26, 41, 32, 7, 161, 45,
|
||||
212, 198, 212, 81, 133, 185, 164, 85, 95, 232, 106, 10,
|
||||
]);
|
||||
let scalar_b = PodScalar([
|
||||
135, 207, 106, 208, 107, 127, 46, 82, 66, 22, 136, 125, 105, 62, 69, 34, 213, 210, 17,
|
||||
196, 120, 114, 238, 237, 149, 170, 5, 243, 54, 77, 172, 12,
|
||||
]);
|
||||
let point_x = PodRistrettoPoint([
|
||||
130, 35, 97, 25, 18, 199, 33, 239, 85, 143, 119, 111, 49, 51, 224, 40, 167, 185, 240,
|
||||
179, 25, 194, 213, 41, 14, 155, 104, 18, 181, 197, 15, 112,
|
||||
]);
|
||||
let point_y = PodRistrettoPoint([
|
||||
152, 156, 155, 197, 152, 232, 92, 206, 219, 159, 193, 134, 121, 128, 139, 36, 56, 191,
|
||||
51, 143, 72, 204, 87, 76, 110, 124, 101, 96, 238, 158, 42, 108,
|
||||
]);
|
||||
|
||||
let ax = multiply_ristretto(&scalar_a, &point_x).unwrap();
|
||||
let by = multiply_ristretto(&scalar_b, &point_y).unwrap();
|
||||
let basic_product = add_ristretto(&ax, &by).unwrap();
|
||||
let msm_product =
|
||||
multiscalar_multiply_ristretto(&[scalar_a, scalar_b], &[point_x, point_y]).unwrap();
|
||||
|
||||
assert_eq!(basic_product, msm_product);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
pub use bytemuck::{Pod, Zeroable};
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
|
||||
#[repr(transparent)]
|
||||
pub struct PodScalar(pub [u8; 32]);
|
||||
|
||||
#[cfg(not(target_arch = "bpf"))]
|
||||
mod target_arch {
|
||||
use {super::*, curve25519_dalek::scalar::Scalar};
|
||||
|
||||
impl From<&Scalar> for PodScalar {
|
||||
fn from(scalar: &Scalar) -> Self {
|
||||
Self(scalar.to_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&PodScalar> for Scalar {
|
||||
fn from(pod: &PodScalar) -> Self {
|
||||
Scalar::from_bits(pod.0)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -32,6 +32,7 @@ mod sigma_proofs;
|
|||
mod transcript;
|
||||
|
||||
// TODO: re-organize visibility
|
||||
pub mod curve25519;
|
||||
pub mod instruction;
|
||||
pub mod zk_token_elgamal;
|
||||
pub mod zk_token_proof_instruction;
|
||||
|
|
Loading…
Reference in New Issue