Avoid exposing implementation details of the square root implementation.

Signed-off-by: Daira Hopwood <daira@jacaranda.org>
This commit is contained in:
Daira Hopwood 2021-01-10 00:23:13 +00:00
parent e13ee2c8ff
commit 227025b7b3
3 changed files with 158 additions and 140 deletions

View File

@ -5,6 +5,7 @@ use core::mem::size_of;
use static_assertions::const_assert;
use std::assert;
use std::convert::TryInto;
use std::marker::PhantomData;
use subtle::{Choice, ConstantTimeEq, CtOption};
use super::Group;
@ -41,14 +42,17 @@ pub trait FieldExt:
/// Element of multiplicative order $3$.
const ZETA: Self;
/// XOR parameter of the perfect hash function used for SqrtTables.
const HASH_XOR: u32;
/// Modulus of the perfect hash function used for SqrtTables.
const HASH_MOD: usize;
/// Tables for square root computation.
fn get_tables() -> &'static SqrtTables<Self>;
/// Computes:
///
/// * (true, sqrt(num/div)), if num and div are nonzero and num/div is a square in the field;
/// * (true, 0), if num is zero;
/// * (false, 0), if num is nonzero and div is zero;
/// * (false, sqrt(ROOT_OF_UNITY * num/div)), if num and div are nonzero and num/div is a nonsquare in the field;
///
/// where ROOT_OF_UNITY is a generator of the order 2^n subgroup (and therefore a nonsquare).
///
/// The choice of root from sqrt is unspecified.
fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self);
/// This computes a random element of the field using system randomness.
fn rand() -> Self {
@ -76,119 +80,6 @@ pub trait FieldExt:
/// byte representation of an integer.
fn from_bytes_wide(bytes: &[u8; 64]) -> Self;
/// Computes:
///
/// * (true, sqrt(num/div)), if num and div are nonzero and num/div is a square in the field;
/// * (true, 0), if num is zero;
/// * (false, 0), if num is nonzero and div is zero;
/// * (false, sqrt(ROOT_OF_UNITY * num/div)), if num and div are nonzero and num/div is a nonsquare in the field;
///
/// where ROOT_OF_UNITY is a generator of the order 2^n subgroup (and therefore a nonsquare).
///
/// The choice of root from sqrt is unspecified.
fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
// Based on:
// * [Sarkar2020](https://eprint.iacr.org/2020/1407)
// * [BDLSY2012](https://cr.yp.to/papers.html#ed25519)
//
// We need to calculate uv and v, where v = u^((m-1)/2), u = num/div, and p-1 = T * 2^S.
// We can rewrite as follows:
//
// v = (num/div)^((T-1)/2)
// = num^((T-1)/2) * div^(p-1 - (T-1)/2) [Fermat's Little Theorem]
// = " * div^(T * 2^S - (T-1)/2)
// = " * div^((2^(S+1) - 1)*(T-1)/2 + 2^S)
// = (num * div^(2^(S+1) - 1))^((T-1)/2) * div^(2^S)
//
// Let w = (num * div^(2^(S+1) - 1))^((T-1)/2) * div^(2^S - 1).
// Then v = w * div, and uv = num * v / div = num * w.
//
// We calculate:
//
// s = div^(2^S - 1) using an addition chain
// t = div^(2^(S+1) - 1) = s^2 * div
// w = (num * t)^((T-1)/2) * s using another addition chain
//
// then u and uv as above. The addition chains are given in
// https://github.com/zcash/pasta/blob/master/addchain_sqrt.py .
// The overall cost of this part is similar to a single full-width exponentiation,
// regardless of S.
let sqr = |x: Self, i: u32| (0..i).fold(x, |x, _| x.square());
// s = div^(2^S - 1)
let s = (0..5).fold(*div, |d: Self, i| sqr(d, 1 << i) * d);
// t == div^(2^(S+1) - 1)
let t = s.square() * div;
// TODO: replace this with an addition chain.
let w = ff::Field::pow_vartime(&(t * num), &Self::T_MINUS1_OVER2) * s;
// v == u^((T-1)/2)
let v = w * div;
// uv = u * v
let uv = w * num;
Self::sqrt_common(num, div, &uv, &v)
}
/// Same as sqrt_ratio but given num, div, v = u^((T-1)/2), and uv = u * v as input.
///
/// The choice of root from sqrt is unspecified.
fn sqrt_common(num: &Self, div: &Self, uv: &Self, v: &Self) -> (Choice, Self) {
let tab = Self::get_tables();
let sqr = |x: Self, i: u32| (0..i).fold(x, |x, _| x.square());
let x3 = *uv * v;
let x2 = sqr(x3, 8);
let x1 = sqr(x2, 8);
let x0 = sqr(x1, 8);
// i = 0, 1
let mut t_: usize = tab.inv[x0.hash()] as usize; // = t >> 16
// 1 == x0 * ROOT_OF_UNITY^(t_ << 24)
assert!(t_ < 0x100);
let alpha = x1 * tab.g2[t_];
// i = 2
t_ += (tab.inv[alpha.hash()] as usize) << 8; // = t >> 8
// 1 == x1 * ROOT_OF_UNITY^(t_ << 16)
assert!(t_ < 0x10000);
let alpha = x2 * tab.g1[t_ & 0xFF] * tab.g2[t_ >> 8];
// i = 3
t_ += (tab.inv[alpha.hash()] as usize) << 16; // = t
// 1 == x2 * ROOT_OF_UNITY^(t_ << 8)
assert!(t_ < 0x1000000);
let alpha = x3 * tab.g0[t_ & 0xFF] * tab.g1[(t_ >> 8) & 0xFF] * tab.g2[t_ >> 16];
t_ += (tab.inv[alpha.hash()] as usize) << 24; // = t << 1
// 1 == x3 * ROOT_OF_UNITY^t_
t_ = (t_ + 1) >> 1;
assert!(t_ <= 0x80000000);
let res = *uv
* tab.g0[t_ & 0xFF]
* tab.g1[(t_ >> 8) & 0xFF]
* tab.g2[(t_ >> 16) & 0xFF]
* tab.g3[t_ >> 24];
let sqdiv = res.square() * div;
let is_square = (sqdiv - num).ct_is_zero();
let is_nonsquare = (sqdiv - Self::ROOT_OF_UNITY * num).ct_is_zero();
assert!(bool::from(
num.ct_is_zero() | div.ct_is_zero() | (is_square ^ is_nonsquare)
));
(is_square, res)
}
/// Returns a perfect hash of this element for use with inv.
fn hash(&self) -> usize {
((self.get_lower_32() ^ Self::HASH_XOR) as usize) % Self::HASH_MOD
}
/// Exponentiates `self` by `by`, where `by` is a little-endian order
/// integer exponent.
fn pow(&self, by: &[u64; 4]) -> Self {
@ -238,9 +129,25 @@ pub trait FieldExt:
}
}
/// Parameters for a perfect hash function used in square root computation.
#[derive(Debug)]
struct SqrtHasher<F: FieldExt> {
hash_xor: u32,
hash_mod: usize,
marker: PhantomData<F>,
}
impl<F: FieldExt> SqrtHasher<F> {
/// Returns a perfect hash of x for use with SqrtTables::inv.
fn hash(&self, x: &F) -> usize {
((x.get_lower_32() ^ self.hash_xor) as usize) % self.hash_mod
}
}
/// Tables used for square root computation.
#[derive(Debug)]
pub struct SqrtTables<F: FieldExt> {
hasher: SqrtHasher<F>,
inv: Vec<u8>,
g0: [F; 256],
g1: [F; 256],
@ -250,7 +157,13 @@ pub struct SqrtTables<F: FieldExt> {
impl<F: FieldExt> SqrtTables<F> {
/// Build tables given parameters for the perfect hash.
pub fn init() -> Self {
pub fn new(hash_xor: u32, hash_mod: usize) -> Self {
let hasher = SqrtHasher {
hash_xor,
hash_mod,
marker: PhantomData,
};
let gtab: Vec<Vec<F>> = (0..4)
.scan(F::ROOT_OF_UNITY, |gi, _| {
// gi == ROOT_OF_UNITY^(256^i)
@ -267,15 +180,16 @@ impl<F: FieldExt> SqrtTables<F> {
.collect();
// Now invert gtab[3].
let mut inv: Vec<u8> = vec![1; F::HASH_MOD];
let mut inv: Vec<u8> = vec![1; hash_mod];
for j in 0..256 {
let hash = gtab[3][j].hash();
let hash = hasher.hash(&gtab[3][j]);
// 1 is the last value to be assigned, so this ensures there are no collisions.
assert!(inv[hash] == 1);
inv[hash] = ((256 - j) & 0xFF) as u8;
}
SqrtTables::<F> {
hasher,
inv,
g0: gtab[0][..].try_into().unwrap(),
g1: gtab[1][..].try_into().unwrap(),
@ -283,6 +197,114 @@ impl<F: FieldExt> SqrtTables<F> {
g3: gtab[3][0..129].try_into().unwrap(),
}
}
/// Computes:
///
/// * (true, sqrt(num/div)), if num and div are nonzero and num/div is a square in the field;
/// * (true, 0), if num is zero;
/// * (false, 0), if num is nonzero and div is zero;
/// * (false, sqrt(ROOT_OF_UNITY * num/div)), if num and div are nonzero and num/div is a nonsquare in the field;
///
/// where ROOT_OF_UNITY is a generator of the order 2^n subgroup (and therefore a nonsquare).
///
/// The choice of root from sqrt is unspecified.
pub fn sqrt_ratio(&self, num: &F, div: &F) -> (Choice, F) {
// Based on:
// * [Sarkar2020](https://eprint.iacr.org/2020/1407)
// * [BDLSY2012](https://cr.yp.to/papers.html#ed25519)
//
// We need to calculate uv and v, where v = u^((m-1)/2), u = num/div, and p-1 = T * 2^S.
// We can rewrite as follows:
//
// v = (num/div)^((T-1)/2)
// = num^((T-1)/2) * div^(p-1 - (T-1)/2) [Fermat's Little Theorem]
// = " * div^(T * 2^S - (T-1)/2)
// = " * div^((2^(S+1) - 1)*(T-1)/2 + 2^S)
// = (num * div^(2^(S+1) - 1))^((T-1)/2) * div^(2^S)
//
// Let w = (num * div^(2^(S+1) - 1))^((T-1)/2) * div^(2^S - 1).
// Then v = w * div, and uv = num * v / div = num * w.
//
// We calculate:
//
// s = div^(2^S - 1) using an addition chain
// t = div^(2^(S+1) - 1) = s^2 * div
// w = (num * t)^((T-1)/2) * s using another addition chain
//
// then u and uv as above. The addition chains are given in
// https://github.com/zcash/pasta/blob/master/addchain_sqrt.py .
// The overall cost of this part is similar to a single full-width exponentiation,
// regardless of S.
let sqr = |x: F, i: u32| (0..i).fold(x, |x, _| x.square());
// s = div^(2^S - 1)
let s = (0..5).fold(*div, |d: F, i| sqr(d, 1 << i) * d);
// t == div^(2^(S+1) - 1)
let t = s.square() * div;
// TODO: replace this with an addition chain.
let w = ff::Field::pow_vartime(&(t * num), &F::T_MINUS1_OVER2) * s;
// v == u^((T-1)/2)
let v = w * div;
// uv = u * v
let uv = w * num;
self.sqrt_common(num, div, &uv, &v)
}
/// Same as sqrt_ratio but given num, div, v = u^((T-1)/2), and uv = u * v as input.
///
/// The choice of root from sqrt is unspecified.
fn sqrt_common(&self, num: &F, div: &F, uv: &F, v: &F) -> (Choice, F) {
let sqr = |x: F, i: u32| (0..i).fold(x, |x, _| x.square());
let inv = |x: F| self.inv[self.hasher.hash(&x)] as usize;
let x3 = *uv * v;
let x2 = sqr(x3, 8);
let x1 = sqr(x2, 8);
let x0 = sqr(x1, 8);
// i = 0, 1
let mut t_ = inv(x0); // = t >> 16
// 1 == x0 * ROOT_OF_UNITY^(t_ << 24)
assert!(t_ < 0x100);
let alpha = x1 * self.g2[t_];
// i = 2
t_ += inv(alpha) << 8; // = t >> 8
// 1 == x1 * ROOT_OF_UNITY^(t_ << 16)
assert!(t_ < 0x10000);
let alpha = x2 * self.g1[t_ & 0xFF] * self.g2[t_ >> 8];
// i = 3
t_ += inv(alpha) << 16; // = t
// 1 == x2 * ROOT_OF_UNITY^(t_ << 8)
assert!(t_ < 0x1000000);
let alpha = x3 * self.g0[t_ & 0xFF] * self.g1[(t_ >> 8) & 0xFF] * self.g2[t_ >> 16];
t_ += inv(alpha) << 24; // = t << 1
// 1 == x3 * ROOT_OF_UNITY^t_
t_ = (t_ + 1) >> 1;
assert!(t_ <= 0x80000000);
let res = *uv
* self.g0[t_ & 0xFF]
* self.g1[(t_ >> 8) & 0xFF]
* self.g2[(t_ >> 16) & 0xFF]
* self.g3[t_ >> 24];
let sqdiv = res.square() * div;
let is_square = (sqdiv - num).ct_is_zero();
let is_nonsquare = (sqdiv - F::ROOT_OF_UNITY * num).ct_is_zero();
assert!(bool::from(
num.ct_is_zero() | div.ct_is_zero() | (is_square ^ is_nonsquare)
));
(is_square, res)
}
}
/// Compute a + b + carry, returning the result and the new carry over.

View File

@ -636,6 +636,10 @@ impl ff::PrimeField for Fp {
}
}
lazy_static! {
static ref FP_TABLES: SqrtTables<Fp> = SqrtTables::new(0x11BE, 1098);
}
impl FieldExt for Fp {
const ROOT_OF_UNITY: Self = ROOT_OF_UNITY;
const ROOT_OF_UNITY_INV: Self = Fp::from_raw([
@ -671,14 +675,8 @@ impl FieldExt for Fp {
0x12ccca834acdba71,
]);
const HASH_XOR: u32 = 0x11BE;
const HASH_MOD: usize = 1098;
fn get_tables() -> &'static SqrtTables<Self> {
lazy_static! {
static ref FP_TABLES: SqrtTables<Fp> = SqrtTables::init();
}
&FP_TABLES
fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
FP_TABLES.sqrt_ratio(num, div)
}
fn ct_is_zero(&self) -> Choice {

View File

@ -636,6 +636,10 @@ impl ff::PrimeField for Fq {
}
}
lazy_static! {
static ref FQ_TABLES: SqrtTables<Fq> = SqrtTables::new(0x116A9E, 1206);
}
impl FieldExt for Fq {
const ROOT_OF_UNITY: Self = ROOT_OF_UNITY;
const ROOT_OF_UNITY_INV: Self = Fq::from_raw([
@ -671,14 +675,8 @@ impl FieldExt for Fq {
0x06819a58283e528e,
]);
const HASH_XOR: u32 = 0x116A9E;
const HASH_MOD: usize = 1206;
fn get_tables() -> &'static SqrtTables<Self> {
lazy_static! {
static ref FQ_TABLES: SqrtTables<Fq> = SqrtTables::init();
}
&FQ_TABLES
fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
FQ_TABLES.sqrt_ratio(num, div)
}
fn ct_is_zero(&self) -> Choice {