[zk-token-sdk] Add range proof generation error types (#34065)
* replace assert statements with `VectorLengthMismatch` error variant * add a condition to check that the bit lengths are in the correct range * replace assert statements with `GeneratorLengthMismatch` * remove unchecked arithmetic * add `InnerProductLengthMismatch` error * fix typo * add a clarifying comment on unwrap safety * fix typo
This commit is contained in:
parent
ecc067f7ad
commit
ded278fb57
|
@ -47,7 +47,7 @@ impl BatchedRangeProofU128Data {
|
|||
.try_fold(0_usize, |acc, &x| acc.checked_add(x))
|
||||
.ok_or(ProofGenerationError::IllegalAmountBitLength)?;
|
||||
|
||||
// `u64::BITS` is 128, which fits in a single byte and should not overflow to `usize` for
|
||||
// `u128::BITS` is 128, which fits in a single byte and should not overflow to `usize` for
|
||||
// an overwhelming number of platforms. However, to be extra cautious, use `try_from` and
|
||||
// `unwrap` here. A simple case `u128::BITS as usize` can silently overflow.
|
||||
let expected_bit_length = usize::try_from(u128::BITS).unwrap();
|
||||
|
|
|
@ -5,6 +5,14 @@ use {crate::errors::TranscriptError, thiserror::Error};
|
|||
pub enum RangeProofGenerationError {
|
||||
#[error("maximum generator length exceeded")]
|
||||
MaximumGeneratorLengthExceeded,
|
||||
#[error("amounts, commitments, openings, or bit lengths vectors have different lengths")]
|
||||
VectorLengthMismatch,
|
||||
#[error("invalid bit size")]
|
||||
InvalidBitSize,
|
||||
#[error("insufficient generators for the proof")]
|
||||
GeneratorLengthMismatch,
|
||||
#[error("inner product length mismatch")]
|
||||
InnerProductLengthMismatch,
|
||||
}
|
||||
|
||||
#[derive(Error, Clone, Debug, Eq, PartialEq)]
|
||||
|
@ -25,6 +33,8 @@ pub enum RangeProofVerificationError {
|
|||
InvalidGeneratorsLength,
|
||||
#[error("maximum generator length exceeded")]
|
||||
MaximumGeneratorLengthExceeded,
|
||||
#[error("commitments and bit lengths vectors have different lengths")]
|
||||
VectorLengthMismatch,
|
||||
}
|
||||
|
||||
#[derive(Error, Clone, Debug, Eq, PartialEq)]
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
use {
|
||||
crate::{
|
||||
range_proof::{errors::RangeProofVerificationError, util},
|
||||
range_proof::{
|
||||
errors::{RangeProofGenerationError, RangeProofVerificationError},
|
||||
util,
|
||||
},
|
||||
transcript::TranscriptProtocol,
|
||||
},
|
||||
core::iter,
|
||||
|
@ -45,7 +48,7 @@ impl InnerProductProof {
|
|||
mut a_vec: Vec<Scalar>,
|
||||
mut b_vec: Vec<Scalar>,
|
||||
transcript: &mut Transcript,
|
||||
) -> Self {
|
||||
) -> Result<Self, RangeProofGenerationError> {
|
||||
// Create slices G, H, a, b backed by their respective
|
||||
// vectors. This lets us reslice as we compress the lengths
|
||||
// of the vectors in the main loop below.
|
||||
|
@ -57,15 +60,20 @@ impl InnerProductProof {
|
|||
let mut n = G.len();
|
||||
|
||||
// All of the input vectors must have the same length.
|
||||
assert_eq!(G.len(), n);
|
||||
assert_eq!(H.len(), n);
|
||||
assert_eq!(a.len(), n);
|
||||
assert_eq!(b.len(), n);
|
||||
assert_eq!(G_factors.len(), n);
|
||||
assert_eq!(H_factors.len(), n);
|
||||
if G.len() != n
|
||||
|| H.len() != n
|
||||
|| a.len() != n
|
||||
|| b.len() != n
|
||||
|| G_factors.len() != n
|
||||
|| H_factors.len() != n
|
||||
{
|
||||
return Err(RangeProofGenerationError::GeneratorLengthMismatch);
|
||||
}
|
||||
|
||||
// All of the input vectors must have a length that is a power of two.
|
||||
assert!(n.is_power_of_two());
|
||||
if !n.is_power_of_two() {
|
||||
return Err(RangeProofGenerationError::InvalidBitSize);
|
||||
}
|
||||
|
||||
transcript.innerproduct_domain_separator(n as u64);
|
||||
|
||||
|
@ -76,18 +84,21 @@ impl InnerProductProof {
|
|||
// If it's the first iteration, unroll the Hprime = H*y_inv scalar mults
|
||||
// into multiscalar muls, for performance.
|
||||
if n != 1 {
|
||||
n /= 2;
|
||||
n = n.checked_div(2).unwrap();
|
||||
let (a_L, a_R) = a.split_at_mut(n);
|
||||
let (b_L, b_R) = b.split_at_mut(n);
|
||||
let (G_L, G_R) = G.split_at_mut(n);
|
||||
let (H_L, H_R) = H.split_at_mut(n);
|
||||
|
||||
let c_L = util::inner_product(a_L, b_R);
|
||||
let c_R = util::inner_product(a_R, b_L);
|
||||
let c_L = util::inner_product(a_L, b_R)
|
||||
.ok_or(RangeProofGenerationError::InnerProductLengthMismatch)?;
|
||||
let c_R = util::inner_product(a_R, b_L)
|
||||
.ok_or(RangeProofGenerationError::InnerProductLengthMismatch)?;
|
||||
|
||||
let L = RistrettoPoint::multiscalar_mul(
|
||||
a_L.iter()
|
||||
.zip(G_factors[n..2 * n].iter())
|
||||
// `n` was previously divided in half and therefore, it cannot overflow.
|
||||
.zip(G_factors[n..n.checked_mul(2).unwrap()].iter())
|
||||
.map(|(a_L_i, g)| a_L_i * g)
|
||||
.chain(
|
||||
b_R.iter()
|
||||
|
@ -105,7 +116,7 @@ impl InnerProductProof {
|
|||
.map(|(a_R_i, g)| a_R_i * g)
|
||||
.chain(
|
||||
b_L.iter()
|
||||
.zip(H_factors[n..2 * n].iter())
|
||||
.zip(H_factors[n..n.checked_mul(2).unwrap()].iter())
|
||||
.map(|(b_L_i, h)| b_L_i * h),
|
||||
)
|
||||
.chain(iter::once(c_R)),
|
||||
|
@ -126,11 +137,17 @@ impl InnerProductProof {
|
|||
a_L[i] = a_L[i] * u + u_inv * a_R[i];
|
||||
b_L[i] = b_L[i] * u_inv + u * b_R[i];
|
||||
G_L[i] = RistrettoPoint::multiscalar_mul(
|
||||
&[u_inv * G_factors[i], u * G_factors[n + i]],
|
||||
&[
|
||||
u_inv * G_factors[i],
|
||||
u * G_factors[n.checked_add(i).unwrap()],
|
||||
],
|
||||
&[G_L[i], G_R[i]],
|
||||
);
|
||||
H_L[i] = RistrettoPoint::multiscalar_mul(
|
||||
&[u * H_factors[i], u_inv * H_factors[n + i]],
|
||||
&[
|
||||
u * H_factors[i],
|
||||
u_inv * H_factors[n.checked_add(i).unwrap()],
|
||||
],
|
||||
&[H_L[i], H_R[i]],
|
||||
)
|
||||
}
|
||||
|
@ -142,14 +159,16 @@ impl InnerProductProof {
|
|||
}
|
||||
|
||||
while n != 1 {
|
||||
n /= 2;
|
||||
n = n.checked_div(2).unwrap();
|
||||
let (a_L, a_R) = a.split_at_mut(n);
|
||||
let (b_L, b_R) = b.split_at_mut(n);
|
||||
let (G_L, G_R) = G.split_at_mut(n);
|
||||
let (H_L, H_R) = H.split_at_mut(n);
|
||||
|
||||
let c_L = util::inner_product(a_L, b_R);
|
||||
let c_R = util::inner_product(a_R, b_L);
|
||||
let c_L = util::inner_product(a_L, b_R)
|
||||
.ok_or(RangeProofGenerationError::InnerProductLengthMismatch)?;
|
||||
let c_R = util::inner_product(a_R, b_L)
|
||||
.ok_or(RangeProofGenerationError::InnerProductLengthMismatch)?;
|
||||
|
||||
let L = RistrettoPoint::multiscalar_mul(
|
||||
a_L.iter().chain(b_R.iter()).chain(iter::once(&c_L)),
|
||||
|
@ -185,12 +204,12 @@ impl InnerProductProof {
|
|||
H = H_L;
|
||||
}
|
||||
|
||||
InnerProductProof {
|
||||
Ok(InnerProductProof {
|
||||
L_vec,
|
||||
R_vec,
|
||||
a: a[0],
|
||||
b: b[0],
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Computes three vectors of verification scalars \\([u\_{i}^{2}]\\), \\([u\_{i}^{-2}]\\) and
|
||||
|
@ -210,7 +229,7 @@ impl InnerProductProof {
|
|||
// and this check prevents overflow in 1<<lg_n below.
|
||||
return Err(RangeProofVerificationError::InvalidBitSize);
|
||||
}
|
||||
if n != (1 << lg_n) {
|
||||
if n != (1_usize.checked_shl(lg_n as u32).unwrap()) {
|
||||
return Err(RangeProofVerificationError::InvalidBitSize);
|
||||
}
|
||||
|
||||
|
@ -244,11 +263,14 @@ impl InnerProductProof {
|
|||
let mut s = Vec::with_capacity(n);
|
||||
s.push(allinv);
|
||||
for i in 1..n {
|
||||
let lg_i = (32 - 1 - (i as u32).leading_zeros()) as usize;
|
||||
let k = 1 << lg_i;
|
||||
let lg_i = 31_u32.checked_sub((i as u32).leading_zeros()).unwrap() as usize;
|
||||
let k = 1_usize.checked_shl(lg_i as u32).unwrap();
|
||||
// The challenges are stored in "creation order" as [u_k,...,u_1],
|
||||
// so u_{lg(i)+1} = is indexed by (lg_n-1) - lg_i
|
||||
let u_lg_i_sq = challenges_sq[(lg_n - 1) - lg_i];
|
||||
let u_lg_i_sq = challenges_sq[lg_n
|
||||
.checked_sub(1)
|
||||
.and_then(|x| x.checked_sub(lg_i))
|
||||
.unwrap()];
|
||||
s.push(s[i - k] * u_lg_i_sq);
|
||||
}
|
||||
|
||||
|
@ -418,7 +440,7 @@ mod tests {
|
|||
|
||||
let a: Vec<_> = (0..n).map(|_| Scalar::random(&mut OsRng)).collect();
|
||||
let b: Vec<_> = (0..n).map(|_| Scalar::random(&mut OsRng)).collect();
|
||||
let c = util::inner_product(&a, &b);
|
||||
let c = util::inner_product(&a, &b).unwrap();
|
||||
|
||||
let G_factors: Vec<Scalar> = iter::repeat(Scalar::one()).take(n).collect();
|
||||
|
||||
|
@ -451,7 +473,8 @@ mod tests {
|
|||
a.clone(),
|
||||
b.clone(),
|
||||
&mut prover_transcript,
|
||||
);
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(proof
|
||||
.verify(
|
||||
|
|
|
@ -75,12 +75,23 @@ impl RangeProof {
|
|||
) -> Result<Self, RangeProofGenerationError> {
|
||||
// amounts, bit-lengths, openings must be same length vectors
|
||||
let m = amounts.len();
|
||||
assert_eq!(bit_lengths.len(), m);
|
||||
assert_eq!(openings.len(), m);
|
||||
if bit_lengths.len() != m || openings.len() != m {
|
||||
return Err(RangeProofGenerationError::VectorLengthMismatch);
|
||||
}
|
||||
|
||||
// each bit length must be greater than 0 for the proof to make sense
|
||||
if bit_lengths
|
||||
.iter()
|
||||
.any(|bit_length| *bit_length == 0 || *bit_length > u64::BITS as usize)
|
||||
{
|
||||
return Err(RangeProofGenerationError::InvalidBitSize);
|
||||
}
|
||||
|
||||
// total vector dimension to compute the ultimate inner product proof for
|
||||
let nm: usize = bit_lengths.iter().sum();
|
||||
assert!(nm.is_power_of_two());
|
||||
if !nm.is_power_of_two() {
|
||||
return Err(RangeProofGenerationError::VectorLengthMismatch);
|
||||
}
|
||||
|
||||
let bp_gens = BulletproofGens::new(nm)
|
||||
.map_err(|_| RangeProofGenerationError::MaximumGeneratorLengthExceeded)?;
|
||||
|
@ -93,7 +104,10 @@ impl RangeProof {
|
|||
for (amount_i, n_i) in amounts.iter().zip(bit_lengths.iter()) {
|
||||
for j in 0..(*n_i) {
|
||||
let (G_ij, H_ij) = gens_iter.next().unwrap();
|
||||
let v_ij = Choice::from(((amount_i >> j) & 1) as u8);
|
||||
|
||||
// `j` is guaranteed to be at most `u64::BITS` (a 6-bit number) and therefore,
|
||||
// casting is lossless and right shift can be safely unwrapped
|
||||
let v_ij = Choice::from((amount_i.checked_shr(j as u32).unwrap() & 1) as u8);
|
||||
let mut point = -H_ij;
|
||||
point.conditional_assign(G_ij, v_ij);
|
||||
A += point;
|
||||
|
@ -138,7 +152,9 @@ impl RangeProof {
|
|||
let mut exp_2 = Scalar::one();
|
||||
|
||||
for j in 0..(*n_i) {
|
||||
let a_L_j = Scalar::from((amount_i >> j) & 1);
|
||||
// `j` is guaranteed to be at most `u64::BITS` (a 6-bit number) and therefore,
|
||||
// casting is lossless and right shift can be safely unwrapped
|
||||
let a_L_j = Scalar::from(amount_i.checked_shr(j as u32).unwrap() & 1);
|
||||
let a_R_j = a_L_j - Scalar::one();
|
||||
|
||||
l_poly.0[i] = a_L_j - z;
|
||||
|
@ -148,13 +164,17 @@ impl RangeProof {
|
|||
|
||||
exp_y *= y;
|
||||
exp_2 = exp_2 + exp_2;
|
||||
i += 1;
|
||||
|
||||
// `i` is capped by the sum of vectors in `bit_lengths`
|
||||
i = i.checked_add(1).unwrap();
|
||||
}
|
||||
exp_z *= z;
|
||||
}
|
||||
|
||||
// define t(x) = <l(x), r(x)> = t_0 + t_1*x + t_2*x
|
||||
let t_poly = l_poly.inner_product(&r_poly);
|
||||
let t_poly = l_poly
|
||||
.inner_product(&r_poly)
|
||||
.ok_or(RangeProofGenerationError::InnerProductLengthMismatch)?;
|
||||
|
||||
// generate Pedersen commitment for the coefficients t_1 and t_2
|
||||
let (T_1, t_1_blinding) = Pedersen::new(t_poly.1);
|
||||
|
@ -216,7 +236,7 @@ impl RangeProof {
|
|||
l_vec,
|
||||
r_vec,
|
||||
transcript,
|
||||
);
|
||||
)?;
|
||||
|
||||
Ok(RangeProof {
|
||||
A,
|
||||
|
@ -238,7 +258,9 @@ impl RangeProof {
|
|||
transcript: &mut Transcript,
|
||||
) -> Result<(), RangeProofVerificationError> {
|
||||
// commitments and bit-lengths must be same length vectors
|
||||
assert_eq!(comms.len(), bit_lengths.len());
|
||||
if comms.len() != bit_lengths.len() {
|
||||
return Err(RangeProofVerificationError::VectorLengthMismatch);
|
||||
}
|
||||
|
||||
let m = bit_lengths.len();
|
||||
let nm: usize = bit_lengths.iter().sum();
|
||||
|
|
|
@ -11,20 +11,20 @@ impl VecPoly1 {
|
|||
VecPoly1(vec![Scalar::zero(); n], vec![Scalar::zero(); n])
|
||||
}
|
||||
|
||||
pub fn inner_product(&self, rhs: &VecPoly1) -> Poly2 {
|
||||
pub fn inner_product(&self, rhs: &VecPoly1) -> Option<Poly2> {
|
||||
// Uses Karatsuba's method
|
||||
let l = self;
|
||||
let r = rhs;
|
||||
|
||||
let t0 = inner_product(&l.0, &r.0);
|
||||
let t2 = inner_product(&l.1, &r.1);
|
||||
let t0 = inner_product(&l.0, &r.0)?;
|
||||
let t2 = inner_product(&l.1, &r.1)?;
|
||||
|
||||
let l0_plus_l1 = add_vec(&l.0, &l.1);
|
||||
let r0_plus_r1 = add_vec(&r.0, &r.1);
|
||||
|
||||
let t1 = inner_product(&l0_plus_l1, &r0_plus_r1) - t0 - t2;
|
||||
let t1 = inner_product(&l0_plus_l1, &r0_plus_r1)? - t0 - t2;
|
||||
|
||||
Poly2(t0, t1, t2)
|
||||
Some(Poly2(t0, t1, t2))
|
||||
}
|
||||
|
||||
pub fn eval(&self, x: Scalar) -> Vec<Scalar> {
|
||||
|
@ -98,16 +98,16 @@ pub fn read32(data: &[u8]) -> [u8; 32] {
|
|||
/// \\[
|
||||
/// {\langle {\mathbf{a}}, {\mathbf{b}} \rangle} = \sum\_{i=0}^{n-1} a\_i \cdot b\_i.
|
||||
/// \\]
|
||||
/// Panics if the lengths of \\(\mathbf{a}\\) and \\(\mathbf{b}\\) are not equal.
|
||||
pub fn inner_product(a: &[Scalar], b: &[Scalar]) -> Scalar {
|
||||
/// Errors if the lengths of \\(\mathbf{a}\\) and \\(\mathbf{b}\\) are not equal.
|
||||
pub fn inner_product(a: &[Scalar], b: &[Scalar]) -> Option<Scalar> {
|
||||
let mut out = Scalar::zero();
|
||||
if a.len() != b.len() {
|
||||
panic!("inner_product(a,b): lengths of vectors do not match");
|
||||
return None;
|
||||
}
|
||||
for i in 0..a.len() {
|
||||
out += a[i] * b[i];
|
||||
}
|
||||
out
|
||||
Some(out)
|
||||
}
|
||||
|
||||
/// Takes the sum of all the powers of `x`, up to `n`
|
||||
|
|
Loading…
Reference in New Issue