mirror of https://github.com/zcash/halo2.git
Avoid exposing implementation details of the square root implementation.
Signed-off-by: Daira Hopwood <daira@jacaranda.org>
This commit is contained in:
parent
e13ee2c8ff
commit
227025b7b3
|
@ -5,6 +5,7 @@ use core::mem::size_of;
|
|||
use static_assertions::const_assert;
|
||||
use std::assert;
|
||||
use std::convert::TryInto;
|
||||
use std::marker::PhantomData;
|
||||
use subtle::{Choice, ConstantTimeEq, CtOption};
|
||||
|
||||
use super::Group;
|
||||
|
@ -41,14 +42,17 @@ pub trait FieldExt:
|
|||
/// Element of multiplicative order $3$.
|
||||
const ZETA: Self;
|
||||
|
||||
/// XOR parameter of the perfect hash function used for SqrtTables.
|
||||
const HASH_XOR: u32;
|
||||
|
||||
/// Modulus of the perfect hash function used for SqrtTables.
|
||||
const HASH_MOD: usize;
|
||||
|
||||
/// Tables for square root computation.
|
||||
fn get_tables() -> &'static SqrtTables<Self>;
|
||||
/// Computes:
|
||||
///
|
||||
/// * (true, sqrt(num/div)), if num and div are nonzero and num/div is a square in the field;
|
||||
/// * (true, 0), if num is zero;
|
||||
/// * (false, 0), if num is nonzero and div is zero;
|
||||
/// * (false, sqrt(ROOT_OF_UNITY * num/div)), if num and div are nonzero and num/div is a nonsquare in the field;
|
||||
///
|
||||
/// where ROOT_OF_UNITY is a generator of the order 2^n subgroup (and therefore a nonsquare).
|
||||
///
|
||||
/// The choice of root from sqrt is unspecified.
|
||||
fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self);
|
||||
|
||||
/// This computes a random element of the field using system randomness.
|
||||
fn rand() -> Self {
|
||||
|
@ -76,119 +80,6 @@ pub trait FieldExt:
|
|||
/// byte representation of an integer.
|
||||
fn from_bytes_wide(bytes: &[u8; 64]) -> Self;
|
||||
|
||||
/// Computes:
|
||||
///
|
||||
/// * (true, sqrt(num/div)), if num and div are nonzero and num/div is a square in the field;
|
||||
/// * (true, 0), if num is zero;
|
||||
/// * (false, 0), if num is nonzero and div is zero;
|
||||
/// * (false, sqrt(ROOT_OF_UNITY * num/div)), if num and div are nonzero and num/div is a nonsquare in the field;
|
||||
///
|
||||
/// where ROOT_OF_UNITY is a generator of the order 2^n subgroup (and therefore a nonsquare).
|
||||
///
|
||||
/// The choice of root from sqrt is unspecified.
|
||||
fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
|
||||
// Based on:
|
||||
// * [Sarkar2020](https://eprint.iacr.org/2020/1407)
|
||||
// * [BDLSY2012](https://cr.yp.to/papers.html#ed25519)
|
||||
//
|
||||
// We need to calculate uv and v, where v = u^((m-1)/2), u = num/div, and p-1 = T * 2^S.
|
||||
// We can rewrite as follows:
|
||||
//
|
||||
// v = (num/div)^((T-1)/2)
|
||||
// = num^((T-1)/2) * div^(p-1 - (T-1)/2) [Fermat's Little Theorem]
|
||||
// = " * div^(T * 2^S - (T-1)/2)
|
||||
// = " * div^((2^(S+1) - 1)*(T-1)/2 + 2^S)
|
||||
// = (num * div^(2^(S+1) - 1))^((T-1)/2) * div^(2^S)
|
||||
//
|
||||
// Let w = (num * div^(2^(S+1) - 1))^((T-1)/2) * div^(2^S - 1).
|
||||
// Then v = w * div, and uv = num * v / div = num * w.
|
||||
//
|
||||
// We calculate:
|
||||
//
|
||||
// s = div^(2^S - 1) using an addition chain
|
||||
// t = div^(2^(S+1) - 1) = s^2 * div
|
||||
// w = (num * t)^((T-1)/2) * s using another addition chain
|
||||
//
|
||||
// then u and uv as above. The addition chains are given in
|
||||
// https://github.com/zcash/pasta/blob/master/addchain_sqrt.py .
|
||||
// The overall cost of this part is similar to a single full-width exponentiation,
|
||||
// regardless of S.
|
||||
|
||||
let sqr = |x: Self, i: u32| (0..i).fold(x, |x, _| x.square());
|
||||
|
||||
// s = div^(2^S - 1)
|
||||
let s = (0..5).fold(*div, |d: Self, i| sqr(d, 1 << i) * d);
|
||||
|
||||
// t == div^(2^(S+1) - 1)
|
||||
let t = s.square() * div;
|
||||
|
||||
// TODO: replace this with an addition chain.
|
||||
let w = ff::Field::pow_vartime(&(t * num), &Self::T_MINUS1_OVER2) * s;
|
||||
|
||||
// v == u^((T-1)/2)
|
||||
let v = w * div;
|
||||
|
||||
// uv = u * v
|
||||
let uv = w * num;
|
||||
|
||||
Self::sqrt_common(num, div, &uv, &v)
|
||||
}
|
||||
|
||||
/// Same as sqrt_ratio but given num, div, v = u^((T-1)/2), and uv = u * v as input.
|
||||
///
|
||||
/// The choice of root from sqrt is unspecified.
|
||||
fn sqrt_common(num: &Self, div: &Self, uv: &Self, v: &Self) -> (Choice, Self) {
|
||||
let tab = Self::get_tables();
|
||||
let sqr = |x: Self, i: u32| (0..i).fold(x, |x, _| x.square());
|
||||
|
||||
let x3 = *uv * v;
|
||||
let x2 = sqr(x3, 8);
|
||||
let x1 = sqr(x2, 8);
|
||||
let x0 = sqr(x1, 8);
|
||||
|
||||
// i = 0, 1
|
||||
let mut t_: usize = tab.inv[x0.hash()] as usize; // = t >> 16
|
||||
// 1 == x0 * ROOT_OF_UNITY^(t_ << 24)
|
||||
assert!(t_ < 0x100);
|
||||
let alpha = x1 * tab.g2[t_];
|
||||
|
||||
// i = 2
|
||||
t_ += (tab.inv[alpha.hash()] as usize) << 8; // = t >> 8
|
||||
// 1 == x1 * ROOT_OF_UNITY^(t_ << 16)
|
||||
assert!(t_ < 0x10000);
|
||||
let alpha = x2 * tab.g1[t_ & 0xFF] * tab.g2[t_ >> 8];
|
||||
|
||||
// i = 3
|
||||
t_ += (tab.inv[alpha.hash()] as usize) << 16; // = t
|
||||
// 1 == x2 * ROOT_OF_UNITY^(t_ << 8)
|
||||
assert!(t_ < 0x1000000);
|
||||
let alpha = x3 * tab.g0[t_ & 0xFF] * tab.g1[(t_ >> 8) & 0xFF] * tab.g2[t_ >> 16];
|
||||
|
||||
t_ += (tab.inv[alpha.hash()] as usize) << 24; // = t << 1
|
||||
// 1 == x3 * ROOT_OF_UNITY^t_
|
||||
t_ = (t_ + 1) >> 1;
|
||||
assert!(t_ <= 0x80000000);
|
||||
let res = *uv
|
||||
* tab.g0[t_ & 0xFF]
|
||||
* tab.g1[(t_ >> 8) & 0xFF]
|
||||
* tab.g2[(t_ >> 16) & 0xFF]
|
||||
* tab.g3[t_ >> 24];
|
||||
|
||||
let sqdiv = res.square() * div;
|
||||
let is_square = (sqdiv - num).ct_is_zero();
|
||||
let is_nonsquare = (sqdiv - Self::ROOT_OF_UNITY * num).ct_is_zero();
|
||||
assert!(bool::from(
|
||||
num.ct_is_zero() | div.ct_is_zero() | (is_square ^ is_nonsquare)
|
||||
));
|
||||
|
||||
(is_square, res)
|
||||
}
|
||||
|
||||
/// Returns a perfect hash of this element for use with inv.
|
||||
fn hash(&self) -> usize {
|
||||
((self.get_lower_32() ^ Self::HASH_XOR) as usize) % Self::HASH_MOD
|
||||
}
|
||||
|
||||
/// Exponentiates `self` by `by`, where `by` is a little-endian order
|
||||
/// integer exponent.
|
||||
fn pow(&self, by: &[u64; 4]) -> Self {
|
||||
|
@ -238,9 +129,25 @@ pub trait FieldExt:
|
|||
}
|
||||
}
|
||||
|
||||
/// Parameters for a perfect hash function used in square root computation.
|
||||
#[derive(Debug)]
|
||||
struct SqrtHasher<F: FieldExt> {
|
||||
hash_xor: u32,
|
||||
hash_mod: usize,
|
||||
marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: FieldExt> SqrtHasher<F> {
|
||||
/// Returns a perfect hash of x for use with SqrtTables::inv.
|
||||
fn hash(&self, x: &F) -> usize {
|
||||
((x.get_lower_32() ^ self.hash_xor) as usize) % self.hash_mod
|
||||
}
|
||||
}
|
||||
|
||||
/// Tables used for square root computation.
|
||||
#[derive(Debug)]
|
||||
pub struct SqrtTables<F: FieldExt> {
|
||||
hasher: SqrtHasher<F>,
|
||||
inv: Vec<u8>,
|
||||
g0: [F; 256],
|
||||
g1: [F; 256],
|
||||
|
@ -250,7 +157,13 @@ pub struct SqrtTables<F: FieldExt> {
|
|||
|
||||
impl<F: FieldExt> SqrtTables<F> {
|
||||
/// Build tables given parameters for the perfect hash.
|
||||
pub fn init() -> Self {
|
||||
pub fn new(hash_xor: u32, hash_mod: usize) -> Self {
|
||||
let hasher = SqrtHasher {
|
||||
hash_xor,
|
||||
hash_mod,
|
||||
marker: PhantomData,
|
||||
};
|
||||
|
||||
let gtab: Vec<Vec<F>> = (0..4)
|
||||
.scan(F::ROOT_OF_UNITY, |gi, _| {
|
||||
// gi == ROOT_OF_UNITY^(256^i)
|
||||
|
@ -267,15 +180,16 @@ impl<F: FieldExt> SqrtTables<F> {
|
|||
.collect();
|
||||
|
||||
// Now invert gtab[3].
|
||||
let mut inv: Vec<u8> = vec![1; F::HASH_MOD];
|
||||
let mut inv: Vec<u8> = vec![1; hash_mod];
|
||||
for j in 0..256 {
|
||||
let hash = gtab[3][j].hash();
|
||||
let hash = hasher.hash(>ab[3][j]);
|
||||
// 1 is the last value to be assigned, so this ensures there are no collisions.
|
||||
assert!(inv[hash] == 1);
|
||||
inv[hash] = ((256 - j) & 0xFF) as u8;
|
||||
}
|
||||
|
||||
SqrtTables::<F> {
|
||||
hasher,
|
||||
inv,
|
||||
g0: gtab[0][..].try_into().unwrap(),
|
||||
g1: gtab[1][..].try_into().unwrap(),
|
||||
|
@ -283,6 +197,114 @@ impl<F: FieldExt> SqrtTables<F> {
|
|||
g3: gtab[3][0..129].try_into().unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes:
|
||||
///
|
||||
/// * (true, sqrt(num/div)), if num and div are nonzero and num/div is a square in the field;
|
||||
/// * (true, 0), if num is zero;
|
||||
/// * (false, 0), if num is nonzero and div is zero;
|
||||
/// * (false, sqrt(ROOT_OF_UNITY * num/div)), if num and div are nonzero and num/div is a nonsquare in the field;
|
||||
///
|
||||
/// where ROOT_OF_UNITY is a generator of the order 2^n subgroup (and therefore a nonsquare).
|
||||
///
|
||||
/// The choice of root from sqrt is unspecified.
|
||||
pub fn sqrt_ratio(&self, num: &F, div: &F) -> (Choice, F) {
|
||||
// Based on:
|
||||
// * [Sarkar2020](https://eprint.iacr.org/2020/1407)
|
||||
// * [BDLSY2012](https://cr.yp.to/papers.html#ed25519)
|
||||
//
|
||||
// We need to calculate uv and v, where v = u^((m-1)/2), u = num/div, and p-1 = T * 2^S.
|
||||
// We can rewrite as follows:
|
||||
//
|
||||
// v = (num/div)^((T-1)/2)
|
||||
// = num^((T-1)/2) * div^(p-1 - (T-1)/2) [Fermat's Little Theorem]
|
||||
// = " * div^(T * 2^S - (T-1)/2)
|
||||
// = " * div^((2^(S+1) - 1)*(T-1)/2 + 2^S)
|
||||
// = (num * div^(2^(S+1) - 1))^((T-1)/2) * div^(2^S)
|
||||
//
|
||||
// Let w = (num * div^(2^(S+1) - 1))^((T-1)/2) * div^(2^S - 1).
|
||||
// Then v = w * div, and uv = num * v / div = num * w.
|
||||
//
|
||||
// We calculate:
|
||||
//
|
||||
// s = div^(2^S - 1) using an addition chain
|
||||
// t = div^(2^(S+1) - 1) = s^2 * div
|
||||
// w = (num * t)^((T-1)/2) * s using another addition chain
|
||||
//
|
||||
// then u and uv as above. The addition chains are given in
|
||||
// https://github.com/zcash/pasta/blob/master/addchain_sqrt.py .
|
||||
// The overall cost of this part is similar to a single full-width exponentiation,
|
||||
// regardless of S.
|
||||
|
||||
let sqr = |x: F, i: u32| (0..i).fold(x, |x, _| x.square());
|
||||
|
||||
// s = div^(2^S - 1)
|
||||
let s = (0..5).fold(*div, |d: F, i| sqr(d, 1 << i) * d);
|
||||
|
||||
// t == div^(2^(S+1) - 1)
|
||||
let t = s.square() * div;
|
||||
|
||||
// TODO: replace this with an addition chain.
|
||||
let w = ff::Field::pow_vartime(&(t * num), &F::T_MINUS1_OVER2) * s;
|
||||
|
||||
// v == u^((T-1)/2)
|
||||
let v = w * div;
|
||||
|
||||
// uv = u * v
|
||||
let uv = w * num;
|
||||
|
||||
self.sqrt_common(num, div, &uv, &v)
|
||||
}
|
||||
|
||||
/// Same as sqrt_ratio but given num, div, v = u^((T-1)/2), and uv = u * v as input.
|
||||
///
|
||||
/// The choice of root from sqrt is unspecified.
|
||||
fn sqrt_common(&self, num: &F, div: &F, uv: &F, v: &F) -> (Choice, F) {
|
||||
let sqr = |x: F, i: u32| (0..i).fold(x, |x, _| x.square());
|
||||
let inv = |x: F| self.inv[self.hasher.hash(&x)] as usize;
|
||||
|
||||
let x3 = *uv * v;
|
||||
let x2 = sqr(x3, 8);
|
||||
let x1 = sqr(x2, 8);
|
||||
let x0 = sqr(x1, 8);
|
||||
|
||||
// i = 0, 1
|
||||
let mut t_ = inv(x0); // = t >> 16
|
||||
// 1 == x0 * ROOT_OF_UNITY^(t_ << 24)
|
||||
assert!(t_ < 0x100);
|
||||
let alpha = x1 * self.g2[t_];
|
||||
|
||||
// i = 2
|
||||
t_ += inv(alpha) << 8; // = t >> 8
|
||||
// 1 == x1 * ROOT_OF_UNITY^(t_ << 16)
|
||||
assert!(t_ < 0x10000);
|
||||
let alpha = x2 * self.g1[t_ & 0xFF] * self.g2[t_ >> 8];
|
||||
|
||||
// i = 3
|
||||
t_ += inv(alpha) << 16; // = t
|
||||
// 1 == x2 * ROOT_OF_UNITY^(t_ << 8)
|
||||
assert!(t_ < 0x1000000);
|
||||
let alpha = x3 * self.g0[t_ & 0xFF] * self.g1[(t_ >> 8) & 0xFF] * self.g2[t_ >> 16];
|
||||
|
||||
t_ += inv(alpha) << 24; // = t << 1
|
||||
// 1 == x3 * ROOT_OF_UNITY^t_
|
||||
t_ = (t_ + 1) >> 1;
|
||||
assert!(t_ <= 0x80000000);
|
||||
let res = *uv
|
||||
* self.g0[t_ & 0xFF]
|
||||
* self.g1[(t_ >> 8) & 0xFF]
|
||||
* self.g2[(t_ >> 16) & 0xFF]
|
||||
* self.g3[t_ >> 24];
|
||||
|
||||
let sqdiv = res.square() * div;
|
||||
let is_square = (sqdiv - num).ct_is_zero();
|
||||
let is_nonsquare = (sqdiv - F::ROOT_OF_UNITY * num).ct_is_zero();
|
||||
assert!(bool::from(
|
||||
num.ct_is_zero() | div.ct_is_zero() | (is_square ^ is_nonsquare)
|
||||
));
|
||||
|
||||
(is_square, res)
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute a + b + carry, returning the result and the new carry over.
|
||||
|
|
|
@ -636,6 +636,10 @@ impl ff::PrimeField for Fp {
|
|||
}
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
static ref FP_TABLES: SqrtTables<Fp> = SqrtTables::new(0x11BE, 1098);
|
||||
}
|
||||
|
||||
impl FieldExt for Fp {
|
||||
const ROOT_OF_UNITY: Self = ROOT_OF_UNITY;
|
||||
const ROOT_OF_UNITY_INV: Self = Fp::from_raw([
|
||||
|
@ -671,14 +675,8 @@ impl FieldExt for Fp {
|
|||
0x12ccca834acdba71,
|
||||
]);
|
||||
|
||||
const HASH_XOR: u32 = 0x11BE;
|
||||
const HASH_MOD: usize = 1098;
|
||||
|
||||
fn get_tables() -> &'static SqrtTables<Self> {
|
||||
lazy_static! {
|
||||
static ref FP_TABLES: SqrtTables<Fp> = SqrtTables::init();
|
||||
}
|
||||
&FP_TABLES
|
||||
fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
|
||||
FP_TABLES.sqrt_ratio(num, div)
|
||||
}
|
||||
|
||||
fn ct_is_zero(&self) -> Choice {
|
||||
|
|
|
@ -636,6 +636,10 @@ impl ff::PrimeField for Fq {
|
|||
}
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
static ref FQ_TABLES: SqrtTables<Fq> = SqrtTables::new(0x116A9E, 1206);
|
||||
}
|
||||
|
||||
impl FieldExt for Fq {
|
||||
const ROOT_OF_UNITY: Self = ROOT_OF_UNITY;
|
||||
const ROOT_OF_UNITY_INV: Self = Fq::from_raw([
|
||||
|
@ -671,14 +675,8 @@ impl FieldExt for Fq {
|
|||
0x06819a58283e528e,
|
||||
]);
|
||||
|
||||
const HASH_XOR: u32 = 0x116A9E;
|
||||
const HASH_MOD: usize = 1206;
|
||||
|
||||
fn get_tables() -> &'static SqrtTables<Self> {
|
||||
lazy_static! {
|
||||
static ref FQ_TABLES: SqrtTables<Fq> = SqrtTables::init();
|
||||
}
|
||||
&FQ_TABLES
|
||||
fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
|
||||
FQ_TABLES.sqrt_ratio(num, div)
|
||||
}
|
||||
|
||||
fn ct_is_zero(&self) -> Choice {
|
||||
|
|
Loading…
Reference in New Issue