solana/zk-token-sdk/src/range_proof/util.rs

139 lines
3.8 KiB
Rust

/// Utility functions for Bulletproofs.
///
/// The code is copied from <https://github.com/dalek-cryptography/bulletproofs> for now.
use curve25519_dalek::scalar::Scalar;
/// A degree-1 vector polynomial \\(\mathbf{a} + \mathbf{b} \cdot x\\).
///
/// Field `0` holds the constant coefficient vector \\(\mathbf{a}\\) and field `1`
/// holds the linear coefficient vector \\(\mathbf{b}\\).
pub struct VecPoly1(pub Vec<Scalar>, pub Vec<Scalar>);

impl VecPoly1 {
    /// Builds the zero polynomial whose two coefficient vectors each contain
    /// `n` zero scalars.
    pub fn zero(n: usize) -> Self {
        let zeros = || vec![Scalar::zero(); n];
        VecPoly1(zeros(), zeros())
    }

    /// Multiplies two degree-1 vector polynomials under the inner product,
    /// yielding the degree-2 scalar polynomial of the result.
    ///
    /// Karatsuba's method is used: the constant and quadratic coefficients are
    /// computed directly, and the linear coefficient is recovered from
    /// `<l0 + l1, r0 + r1> - t0 - t2`, so only three inner products are needed.
    ///
    /// # Panics
    ///
    /// Panics if the coefficient vectors being multiplied differ in length
    /// (the check lives in the free function `inner_product`).
    pub fn inner_product(&self, rhs: &VecPoly1) -> Poly2 {
        let VecPoly1(l0, l1) = self;
        let VecPoly1(r0, r1) = rhs;

        let constant = inner_product(l0, r0);
        let quadratic = inner_product(l1, r1);
        let cross = inner_product(&add_vec(l0, l1), &add_vec(r0, r1));

        Poly2(constant, cross - constant - quadratic, quadratic)
    }

    /// Evaluates the polynomial at `x`, producing one scalar per entry of the
    /// constant coefficient vector.
    pub fn eval(&self, x: Scalar) -> Vec<Scalar> {
        let VecPoly1(constant, linear) = self;
        // Indexed on purpose: the output length is driven by the constant
        // vector, and a shorter linear vector is an out-of-bounds panic.
        (0..constant.len())
            .map(|i| constant[i] + linear[i] * x)
            .collect()
    }
}
/// A degree-2 scalar polynomial \\(a + b \cdot x + c \cdot x^2\\).
///
/// The fields hold the coefficients `a`, `b` and `c`, in that order.
pub struct Poly2(pub Scalar, pub Scalar, pub Scalar);

impl Poly2 {
    /// Evaluates the polynomial at `x` using Horner's rule.
    pub fn eval(&self, x: Scalar) -> Scalar {
        let Poly2(constant, linear, quadratic) = *self;
        let inner = linear + x * quadratic;
        constant + x * inner
    }
}
/// An endless iterator over successive powers of a `Scalar`.
///
/// Values of this type are produced by the `exp_iter` function.
pub struct ScalarExp {
    x: Scalar,
    next_exp_x: Scalar,
}

impl Iterator for ScalarExp {
    type Item = Scalar;

    /// Yields the current power and advances the state by one multiplication
    /// with the base. This never returns `None`.
    fn next(&mut self) -> Option<Scalar> {
        let current = self.next_exp_x;
        self.next_exp_x = current * self.x;
        Some(current)
    }

    /// Reports the largest possible lower bound and no upper bound, since the
    /// sequence never ends.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (usize::MAX, None)
    }
}
/// Returns an iterator over the powers of `x`, beginning with `x^0 = 1`.
pub fn exp_iter(x: Scalar) -> ScalarExp {
    ScalarExp {
        x,
        next_exp_x: Scalar::one(),
    }
}
pub fn add_vec(a: &[Scalar], b: &[Scalar]) -> Vec<Scalar> {
if a.len() != b.len() {
// throw some error
//println!("lengths of vectors don't match for vector addition");
}
let mut out = vec![Scalar::zero(); b.len()];
for i in 0..a.len() {
out[i] = a[i] + b[i];
}
out
}
/// Copies the first 32 bytes of `data` into a fixed-size array.
///
/// # Panics
///
/// Panics if `data` holds fewer than 32 bytes.
pub fn read32(data: &[u8]) -> [u8; 32] {
    // Slice first so that a short input fails on the range check.
    let head = &data[..32];
    let mut out = [0u8; 32];
    out.copy_from_slice(head);
    out
}
/// Computes the inner product of two scalar vectors
/// \\[
///    {\langle {\mathbf{a}}, {\mathbf{b}} \rangle} = \sum\_{i=0}^{n-1} a\_i \cdot b\_i.
/// \\]
///
/// # Panics
///
/// Panics if the lengths of \\(\mathbf{a}\\) and \\(\mathbf{b}\\) are not equal.
pub fn inner_product(a: &[Scalar], b: &[Scalar]) -> Scalar {
    if a.len() != b.len() {
        panic!("inner_product(a,b): lengths of vectors do not match");
    }
    a.iter()
        .zip(b.iter())
        .fold(Scalar::zero(), |acc, (a_i, b_i)| acc + a_i * b_i)
}
/// Computes the sum of the first `n` powers of `x`, i.e.
/// \\(x^0 + x^1 + \dots + x^{n-1}\\).
///
/// When `n` is a power of 2 the sum is built by repeated doubling, which costs
/// `2*lg n` multiplications and additions. For any other `n` (including 0) it
/// falls back to the linear-time `sum_of_powers_slow`.
/// In the Bulletproofs case, all calls to `sum_of_powers` should have `n` as a
/// power of 2.
pub fn sum_of_powers(x: &Scalar, n: usize) -> Scalar {
    if !n.is_power_of_two() {
        return sum_of_powers_slow(x, n);
    }
    match n {
        0 | 1 => Scalar::from(n as u64),
        _ => {
            // Here n = 2^k with k >= 1. Start with the two-term sum 1 + x and
            // double the number of covered terms k - 1 times:
            //   S_{2m} = S_m + x^m * S_m
            let mut power = *x;
            let mut partial = Scalar::one() + x;
            for _ in 1..n.trailing_zeros() {
                power = power * power;
                partial = partial + power * partial;
            }
            partial
        }
    }
}
// takes the sum of all of the powers of x, up to n
fn sum_of_powers_slow(x: &Scalar, n: usize) -> Scalar {
exp_iter(*x).take(n).sum()
}