From 227025b7b39d5853e07594a065e68abdbcc6f24c Mon Sep 17 00:00:00 2001 From: Daira Hopwood Date: Sun, 10 Jan 2021 00:23:13 +0000 Subject: [PATCH] Avoid exposing implementation details of the square root implementation. Signed-off-by: Daira Hopwood --- src/arithmetic/fields.rs | 270 +++++++++++++++++++++------------------ src/pasta/fields/fp.rs | 14 +- src/pasta/fields/fq.rs | 14 +- 3 files changed, 158 insertions(+), 140 deletions(-) diff --git a/src/arithmetic/fields.rs b/src/arithmetic/fields.rs index 1d2a0b33..07fcb880 100644 --- a/src/arithmetic/fields.rs +++ b/src/arithmetic/fields.rs @@ -5,6 +5,7 @@ use core::mem::size_of; use static_assertions::const_assert; use std::assert; use std::convert::TryInto; +use std::marker::PhantomData; use subtle::{Choice, ConstantTimeEq, CtOption}; use super::Group; @@ -41,14 +42,17 @@ pub trait FieldExt: /// Element of multiplicative order $3$. const ZETA: Self; - /// XOR parameter of the perfect hash function used for SqrtTables. - const HASH_XOR: u32; - - /// Modulus of the perfect hash function used for SqrtTables. - const HASH_MOD: usize; - - /// Tables for square root computation. - fn get_tables() -> &'static SqrtTables; + /// Computes: + /// + /// * (true, sqrt(num/div)), if num and div are nonzero and num/div is a square in the field; + /// * (true, 0), if num is zero; + /// * (false, 0), if num is nonzero and div is zero; + /// * (false, sqrt(ROOT_OF_UNITY * num/div)), if num and div are nonzero and num/div is a nonsquare in the field; + /// + /// where ROOT_OF_UNITY is a generator of the order 2^n subgroup (and therefore a nonsquare). + /// + /// The choice of root from sqrt is unspecified. + fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self); /// This computes a random element of the field using system randomness. fn rand() -> Self { @@ -76,119 +80,6 @@ pub trait FieldExt: /// byte representation of an integer. fn from_bytes_wide(bytes: &[u8; 64]) -> Self; - /// Computes: - /// - /// * (true, sqrt(num/div)), if num and div are nonzero and num/div is a square in the field; - /// * (true, 0), if num is zero; - /// * (false, 0), if num is nonzero and div is zero; - /// * (false, sqrt(ROOT_OF_UNITY * num/div)), if num and div are nonzero and num/div is a nonsquare in the field; - /// - /// where ROOT_OF_UNITY is a generator of the order 2^n subgroup (and therefore a nonsquare). - /// - /// The choice of root from sqrt is unspecified. - fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) { - // Based on: - // * [Sarkar2020](https://eprint.iacr.org/2020/1407) - // * [BDLSY2012](https://cr.yp.to/papers.html#ed25519) - // - // We need to calculate uv and v, where v = u^((m-1)/2), u = num/div, and p-1 = T * 2^S. - // We can rewrite as follows: - // - // v = (num/div)^((T-1)/2) - // = num^((T-1)/2) * div^(p-1 - (T-1)/2) [Fermat's Little Theorem] - // = " * div^(T * 2^S - (T-1)/2) - // = " * div^((2^(S+1) - 1)*(T-1)/2 + 2^S) - // = (num * div^(2^(S+1) - 1))^((T-1)/2) * div^(2^S) - // - // Let w = (num * div^(2^(S+1) - 1))^((T-1)/2) * div^(2^S - 1). - // Then v = w * div, and uv = num * v / div = num * w. - // - // We calculate: - // - // s = div^(2^S - 1) using an addition chain - // t = div^(2^(S+1) - 1) = s^2 * div - // w = (num * t)^((T-1)/2) * s using another addition chain - // - // then u and uv as above. The addition chains are given in - // https://github.com/zcash/pasta/blob/master/addchain_sqrt.py . - // The overall cost of this part is similar to a single full-width exponentiation, - // regardless of S. - - let sqr = |x: Self, i: u32| (0..i).fold(x, |x, _| x.square()); - - // s = div^(2^S - 1) - let s = (0..5).fold(*div, |d: Self, i| sqr(d, 1 << i) * d); - - // t == div^(2^(S+1) - 1) - let t = s.square() * div; - - // TODO: replace this with an addition chain. - let w = ff::Field::pow_vartime(&(t * num), &Self::T_MINUS1_OVER2) * s; - - // v == u^((T-1)/2) - let v = w * div; - - // uv = u * v - let uv = w * num; - - Self::sqrt_common(num, div, &uv, &v) - } - - /// Same as sqrt_ratio but given num, div, v = u^((T-1)/2), and uv = u * v as input. - /// - /// The choice of root from sqrt is unspecified. - fn sqrt_common(num: &Self, div: &Self, uv: &Self, v: &Self) -> (Choice, Self) { - let tab = Self::get_tables(); - let sqr = |x: Self, i: u32| (0..i).fold(x, |x, _| x.square()); - - let x3 = *uv * v; - let x2 = sqr(x3, 8); - let x1 = sqr(x2, 8); - let x0 = sqr(x1, 8); - - // i = 0, 1 - let mut t_: usize = tab.inv[x0.hash()] as usize; // = t >> 16 - // 1 == x0 * ROOT_OF_UNITY^(t_ << 24) - assert!(t_ < 0x100); - let alpha = x1 * tab.g2[t_]; - - // i = 2 - t_ += (tab.inv[alpha.hash()] as usize) << 8; // = t >> 8 - // 1 == x1 * ROOT_OF_UNITY^(t_ << 16) - assert!(t_ < 0x10000); - let alpha = x2 * tab.g1[t_ & 0xFF] * tab.g2[t_ >> 8]; - - // i = 3 - t_ += (tab.inv[alpha.hash()] as usize) << 16; // = t - // 1 == x2 * ROOT_OF_UNITY^(t_ << 8) - assert!(t_ < 0x1000000); - let alpha = x3 * tab.g0[t_ & 0xFF] * tab.g1[(t_ >> 8) & 0xFF] * tab.g2[t_ >> 16]; - - t_ += (tab.inv[alpha.hash()] as usize) << 24; // = t << 1 - // 1 == x3 * ROOT_OF_UNITY^t_ - t_ = (t_ + 1) >> 1; - assert!(t_ <= 0x80000000); - let res = *uv - * tab.g0[t_ & 0xFF] - * tab.g1[(t_ >> 8) & 0xFF] - * tab.g2[(t_ >> 16) & 0xFF] - * tab.g3[t_ >> 24]; - - let sqdiv = res.square() * div; - let is_square = (sqdiv - num).ct_is_zero(); - let is_nonsquare = (sqdiv - Self::ROOT_OF_UNITY * num).ct_is_zero(); - assert!(bool::from( - num.ct_is_zero() | div.ct_is_zero() | (is_square ^ is_nonsquare) - )); - - (is_square, res) - } - - /// Returns a perfect hash of this element for use with inv. - fn hash(&self) -> usize { - ((self.get_lower_32() ^ Self::HASH_XOR) as usize) % Self::HASH_MOD - } - /// Exponentiates `self` by `by`, where `by` is a little-endian order /// integer exponent. fn pow(&self, by: &[u64; 4]) -> Self { @@ -238,9 +129,25 @@ pub trait FieldExt: } } +/// Parameters for a perfect hash function used in square root computation. +#[derive(Debug)] +struct SqrtHasher { + hash_xor: u32, + hash_mod: usize, + marker: PhantomData, +} + +impl SqrtHasher { + /// Returns a perfect hash of x for use with SqrtTables::inv. + fn hash(&self, x: &F) -> usize { + ((x.get_lower_32() ^ self.hash_xor) as usize) % self.hash_mod + } +} + /// Tables used for square root computation. #[derive(Debug)] pub struct SqrtTables { + hasher: SqrtHasher, inv: Vec, g0: [F; 256], g1: [F; 256], @@ -250,7 +157,13 @@ pub struct SqrtTables { impl SqrtTables { /// Build tables given parameters for the perfect hash. - pub fn init() -> Self { + pub fn new(hash_xor: u32, hash_mod: usize) -> Self { + let hasher = SqrtHasher { + hash_xor, + hash_mod, + marker: PhantomData, + }; + let gtab: Vec> = (0..4) .scan(F::ROOT_OF_UNITY, |gi, _| { // gi == ROOT_OF_UNITY^(256^i) @@ -267,15 +180,16 @@ impl SqrtTables { .collect(); // Now invert gtab[3]. - let mut inv: Vec = vec![1; F::HASH_MOD]; + let mut inv: Vec = vec![1; hash_mod]; for j in 0..256 { - let hash = gtab[3][j].hash(); + let hash = hasher.hash(>ab[3][j]); // 1 is the last value to be assigned, so this ensures there are no collisions. assert!(inv[hash] == 1); inv[hash] = ((256 - j) & 0xFF) as u8; } SqrtTables:: { + hasher, inv, g0: gtab[0][..].try_into().unwrap(), g1: gtab[1][..].try_into().unwrap(), @@ -283,6 +197,114 @@ impl SqrtTables { g3: gtab[3][0..129].try_into().unwrap(), } } + + /// Computes: + /// + /// * (true, sqrt(num/div)), if num and div are nonzero and num/div is a square in the field; + /// * (true, 0), if num is zero; + /// * (false, 0), if num is nonzero and div is zero; + /// * (false, sqrt(ROOT_OF_UNITY * num/div)), if num and div are nonzero and num/div is a nonsquare in the field; + /// + /// where ROOT_OF_UNITY is a generator of the order 2^n subgroup (and therefore a nonsquare). + /// + /// The choice of root from sqrt is unspecified. + pub fn sqrt_ratio(&self, num: &F, div: &F) -> (Choice, F) { + // Based on: + // * [Sarkar2020](https://eprint.iacr.org/2020/1407) + // * [BDLSY2012](https://cr.yp.to/papers.html#ed25519) + // + // We need to calculate uv and v, where v = u^((m-1)/2), u = num/div, and p-1 = T * 2^S. + // We can rewrite as follows: + // + // v = (num/div)^((T-1)/2) + // = num^((T-1)/2) * div^(p-1 - (T-1)/2) [Fermat's Little Theorem] + // = " * div^(T * 2^S - (T-1)/2) + // = " * div^((2^(S+1) - 1)*(T-1)/2 + 2^S) + // = (num * div^(2^(S+1) - 1))^((T-1)/2) * div^(2^S) + // + // Let w = (num * div^(2^(S+1) - 1))^((T-1)/2) * div^(2^S - 1). + // Then v = w * div, and uv = num * v / div = num * w. + // + // We calculate: + // + // s = div^(2^S - 1) using an addition chain + // t = div^(2^(S+1) - 1) = s^2 * div + // w = (num * t)^((T-1)/2) * s using another addition chain + // + // then u and uv as above. The addition chains are given in + // https://github.com/zcash/pasta/blob/master/addchain_sqrt.py . + // The overall cost of this part is similar to a single full-width exponentiation, + // regardless of S. + + let sqr = |x: F, i: u32| (0..i).fold(x, |x, _| x.square()); + + // s = div^(2^S - 1) + let s = (0..5).fold(*div, |d: F, i| sqr(d, 1 << i) * d); + + // t == div^(2^(S+1) - 1) + let t = s.square() * div; + + // TODO: replace this with an addition chain. + let w = ff::Field::pow_vartime(&(t * num), &F::T_MINUS1_OVER2) * s; + + // v == u^((T-1)/2) + let v = w * div; + + // uv = u * v + let uv = w * num; + + self.sqrt_common(num, div, &uv, &v) + } + + /// Same as sqrt_ratio but given num, div, v = u^((T-1)/2), and uv = u * v as input. + /// + /// The choice of root from sqrt is unspecified. + fn sqrt_common(&self, num: &F, div: &F, uv: &F, v: &F) -> (Choice, F) { + let sqr = |x: F, i: u32| (0..i).fold(x, |x, _| x.square()); + let inv = |x: F| self.inv[self.hasher.hash(&x)] as usize; + + let x3 = *uv * v; + let x2 = sqr(x3, 8); + let x1 = sqr(x2, 8); + let x0 = sqr(x1, 8); + + // i = 0, 1 + let mut t_ = inv(x0); // = t >> 16 + // 1 == x0 * ROOT_OF_UNITY^(t_ << 24) + assert!(t_ < 0x100); + let alpha = x1 * self.g2[t_]; + + // i = 2 + t_ += inv(alpha) << 8; // = t >> 8 + // 1 == x1 * ROOT_OF_UNITY^(t_ << 16) + assert!(t_ < 0x10000); + let alpha = x2 * self.g1[t_ & 0xFF] * self.g2[t_ >> 8]; + + // i = 3 + t_ += inv(alpha) << 16; // = t + // 1 == x2 * ROOT_OF_UNITY^(t_ << 8) + assert!(t_ < 0x1000000); + let alpha = x3 * self.g0[t_ & 0xFF] * self.g1[(t_ >> 8) & 0xFF] * self.g2[t_ >> 16]; + + t_ += inv(alpha) << 24; // = t << 1 + // 1 == x3 * ROOT_OF_UNITY^t_ + t_ = (t_ + 1) >> 1; + assert!(t_ <= 0x80000000); + let res = *uv + * self.g0[t_ & 0xFF] + * self.g1[(t_ >> 8) & 0xFF] + * self.g2[(t_ >> 16) & 0xFF] + * self.g3[t_ >> 24]; + + let sqdiv = res.square() * div; + let is_square = (sqdiv - num).ct_is_zero(); + let is_nonsquare = (sqdiv - F::ROOT_OF_UNITY * num).ct_is_zero(); + assert!(bool::from( + num.ct_is_zero() | div.ct_is_zero() | (is_square ^ is_nonsquare) + )); + + (is_square, res) + } } /// Compute a + b + carry, returning the result and the new carry over. diff --git a/src/pasta/fields/fp.rs b/src/pasta/fields/fp.rs index 3dae877f..0857de10 100644 --- a/src/pasta/fields/fp.rs +++ b/src/pasta/fields/fp.rs @@ -636,6 +636,10 @@ impl ff::PrimeField for Fp { } } +lazy_static! { + static ref FP_TABLES: SqrtTables = SqrtTables::new(0x11BE, 1098); +} + impl FieldExt for Fp { const ROOT_OF_UNITY: Self = ROOT_OF_UNITY; const ROOT_OF_UNITY_INV: Self = Fp::from_raw([ @@ -671,14 +675,8 @@ impl FieldExt for Fp { 0x12ccca834acdba71, ]); - const HASH_XOR: u32 = 0x11BE; - const HASH_MOD: usize = 1098; - - fn get_tables() -> &'static SqrtTables { - lazy_static! { - static ref FP_TABLES: SqrtTables = SqrtTables::init(); - } - &FP_TABLES + fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) { + FP_TABLES.sqrt_ratio(num, div) } fn ct_is_zero(&self) -> Choice { diff --git a/src/pasta/fields/fq.rs b/src/pasta/fields/fq.rs index c93ac54f..9d80a81c 100644 --- a/src/pasta/fields/fq.rs +++ b/src/pasta/fields/fq.rs @@ -636,6 +636,10 @@ impl ff::PrimeField for Fq { } } +lazy_static! { + static ref FQ_TABLES: SqrtTables = SqrtTables::new(0x116A9E, 1206); +} + impl FieldExt for Fq { const ROOT_OF_UNITY: Self = ROOT_OF_UNITY; const ROOT_OF_UNITY_INV: Self = Fq::from_raw([ @@ -671,14 +675,8 @@ impl FieldExt for Fq { 0x06819a58283e528e, ]); - const HASH_XOR: u32 = 0x116A9E; - const HASH_MOD: usize = 1206; - - fn get_tables() -> &'static SqrtTables { - lazy_static! { - static ref FQ_TABLES: SqrtTables = SqrtTables::init(); - } - &FQ_TABLES + fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) { + FQ_TABLES.sqrt_ratio(num, div) } fn ct_is_zero(&self) -> Choice {