pasta_curves/src/arithmetic/fields.rs

421 lines
14 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! This module contains the `Field` abstraction that allows us to write
//! code that generalizes over a pair of fields.
use core::mem::size_of;
use static_assertions::const_assert;
use subtle::{Choice, ConditionallySelectable, CtOption};
use super::Group;
use core::assert;
#[cfg(feature = "sqrt-table")]
use alloc::{boxed::Box, vec::Vec};
#[cfg(feature = "sqrt-table")]
use core::{convert::TryInto, marker::PhantomData};
const_assert!(size_of::<usize>() >= 4);
/// Extra square-root related operations for prime-order finite fields, layered on
/// top of [`ff::PrimeField`].
pub trait SqrtRatio: ff::PrimeField {
    /// The value $(T-1)/2$, where $T$ is odd and $2^S \cdot T = p - 1$.
    const T_MINUS1_OVER2: [u64; 4];

    /// Returns `self` raised to the power [`Self::T_MINUS1_OVER2`].
    ///
    /// The default delegates to the generic variable-time exponentiation; field
    /// implementations may override it with an efficient addition chain.
    fn pow_by_t_minus1_over2(&self) -> Self {
        <Self as ff::Field>::pow_vartime(self, &Self::T_MINUS1_OVER2)
    }

    /// Returns the lower 32 bits of this field element in its canonical
    /// representation.
    fn get_lower_32(&self) -> u32;

    /// Computes a square root of a ratio together with a flag. The result is:
    ///
    /// - $(\textsf{true}, \sqrt{\textsf{num}/\textsf{div}})$, if $\textsf{num}$ and
    ///   $\textsf{div}$ are nonzero and $\textsf{num}/\textsf{div}$ is a square in the
    ///   field;
    /// - $(\textsf{true}, 0)$, if $\textsf{num}$ is zero;
    /// - $(\textsf{false}, 0)$, if $\textsf{num}$ is nonzero and $\textsf{div}$ is zero;
    /// - $(\textsf{false}, \sqrt{G_S \cdot \textsf{num}/\textsf{div}})$, if
    ///   $\textsf{num}$ and $\textsf{div}$ are nonzero and $\textsf{num}/\textsf{div}$ is
    ///   a nonsquare in the field;
    ///
    /// where $G_S$ is a non-square.
    ///
    /// For `pasta_curves`, $G_S$ is currently [`ff::PrimeField::root_of_unity`], a
    /// generator of the order $2^S$ subgroup. Users of this crate should not rely on this
    /// generator being fixed; it may be changed in future crate versions to simplify the
    /// implementation of the SSWU hash-to-curve algorithm.
    ///
    /// Which of the two square roots is returned is unspecified.
    ///
    /// # Panics
    ///
    /// Panics if `num` and `div` are both nonzero and it is not the case that exactly
    /// one of `num/div` and `G_S * num/div` has a square root.
    fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
        // Generic strategy:
        //
        //   ratio   = num * inv0(div)   -- 0 when div is zero, num/div otherwise
        //   twisted = G_S * ratio       -- 0 when div is zero, G_S*num/div otherwise
        //
        // G_S is a non-square, so `ratio` and `twisted` are either both zero (hence
        // both square) or exactly one of them is square. The returned root is chosen
        // by whether `ratio` is square; only the flag needs special treatment for the
        // "nonzero divided by zero" case.
        let ratio = div.invert().unwrap_or_else(Self::zero) * num;
        let twisted = ratio * Self::root_of_unity();

        let root_of_ratio = ratio.sqrt();
        let root_of_twisted = twisted.sqrt();
        let ratio_is_square = root_of_ratio.is_some();
        let twisted_is_square = root_of_twisted.is_some();

        let num_is_zero = num.is_zero();
        let div_is_zero = div.is_zero();

        // Sanity check of the "exactly one is square" argument above.
        assert!(bool::from(
            num_is_zero | div_is_zero | (ratio_is_square ^ twisted_is_square)
        ));

        // A nonzero numerator over a zero divisor must report `false`.
        let nonzero_over_zero = !num_is_zero & div_is_zero;
        let flag = ratio_is_square & !nonzero_over_zero;
        let root =
            CtOption::conditional_select(&root_of_twisted, &root_of_ratio, ratio_is_square)
                .unwrap();

        (flag, root)
    }

    /// Equivalent to `Self::sqrt_ratio(self, one())`.
    fn sqrt_alt(&self) -> (Choice, Self) {
        let one = Self::one();
        Self::sqrt_ratio(self, &one)
    }
}
/// This trait is a common interface for dealing with elements of a finite
/// field.
pub trait FieldExt: SqrtRatio + From<bool> + Ord + Group<Scalar = Self> {
/// Modulus of the field written as a string for display purposes
const MODULUS: &'static str;
/// Inverse of `PrimeField::root_of_unity()`
const ROOT_OF_UNITY_INV: Self;
/// Generator of the $t-order$ multiplicative subgroup
const DELTA: Self;
/// Inverse of $2$ in the field.
const TWO_INV: Self;
/// Element of multiplicative order $3$.
const ZETA: Self;
/// Obtains a field element congruent to the integer `v`.
fn from_u128(v: u128) -> Self;
/// Obtains a field element that is congruent to the provided little endian
/// byte representation of an integer.
fn from_bytes_wide(bytes: &[u8; 64]) -> Self;
/// Exponentiates `self` by `by`, where `by` is a little-endian order
/// integer exponent.
fn pow(&self, by: &[u64; 4]) -> Self {
let mut res = Self::one();
for e in by.iter().rev() {
for i in (0..64).rev() {
res = res.square();
let mut tmp = res;
tmp *= self;
res.conditional_assign(&tmp, (((*e >> i) & 0x1) as u8).into());
}
}
res
}
/// Gets the lower 128 bits of this field element when expressed
/// canonically.
fn get_lower_128(&self) -> u128;
}
/// Tonelli-Shanks' square-root algorithm for `p mod 16 = 1`.
///
/// <https://eprint.iacr.org/2012/685.pdf> (page 12, algorithm 5)
///
/// `tm1d2` should be set to `(t - 1) // 2`, where `t = (modulus - 1) >> F::S`.
///
/// Returns a [`CtOption`] wrapping the computed candidate `x`; it is `Some` only
/// when `x * x == f`.
///
/// Apart from the initial `pow_vartime` call, every data-dependent decision in the
/// body goes through `conditional_select`, and both loop bounds depend only on
/// `F::S`, not on `f`.
#[cfg(not(feature = "sqrt-table"))]
#[cfg_attr(docsrs, doc(cfg(not(feature = "sqrt-table"))))]
pub(crate) fn sqrt_tonelli_shanks<F: ff::PrimeField, S: AsRef<[u64]>>(
f: &F,
tm1d2: S,
) -> CtOption<F> {
// Brings `ct_eq` into scope for the comparisons below.
use subtle::ConstantTimeEq;
// w = self^((t - 1) // 2)
let w = f.pow_vartime(tm1d2);
// `v` starts at `F::S` and is replaced by `k` at the end of every outer iteration.
let mut v = F::S;
// With `tm1d2` as documented: x = f^((t + 1) / 2) and b = x * w = f^t.
let mut x = w * f;
let mut b = x * w;
// Initialize z as the 2^S root of unity.
let mut z = F::root_of_unity();
for max_v in (1..=F::S).rev() {
let mut k = 1;
let mut tmp = b.square();
// Stays set until `j` has been equal to `v`; after that `z` is no longer updated
// within this outer iteration.
let mut j_less_than_v: Choice = 1.into();
for j in 2..max_v {
let tmp_is_one = tmp.ct_eq(&F::one());
// One squaring per step: of `tmp` while it is not yet one, of `z` afterwards.
let squared = F::conditional_select(&tmp, &z, tmp_is_one).square();
// `tmp` is frozen once it has become one.
tmp = F::conditional_select(&squared, &tmp, tmp_is_one);
// Candidate for `z`: its square once `tmp` is one, otherwise unchanged.
let new_z = F::conditional_select(&z, &squared, tmp_is_one);
j_less_than_v &= !j.ct_eq(&v);
// `k` follows `j` for as long as `tmp` is not yet one.
k = u32::conditional_select(&j, &k, tmp_is_one);
z = F::conditional_select(&z, &new_z, j_less_than_v);
}
let result = x * z;
// Keep `x` when `b` is already one; otherwise take `x * z`.
x = F::conditional_select(&result, &x, b.ct_eq(&F::one()));
z = z.square();
b *= z;
v = k;
}
CtOption::new(
x,
(x * x).ct_eq(f), // Only return Some if it's the square root.
)
}
/// Parameters for a perfect hash function used in square root computation.
///
/// The hash itself is computed by `SqrtHasher::hash` as
/// `((x.get_lower_32() ^ hash_xor) as usize) % hash_mod`.
#[cfg(feature = "sqrt-table")]
#[cfg_attr(docsrs, doc(cfg(feature = "sqrt-table")))]
#[derive(Debug)]
struct SqrtHasher<F: FieldExt> {
/// Salt XORed with the low 32 bits of the element before the reduction.
hash_xor: u32,
/// Modulus applied to the salted value; `SqrtTables::new` also uses it as the
/// length of its `inv` table.
hash_mod: usize,
/// Ties the hasher to the field type `F` without storing a field element.
marker: PhantomData<F>,
}
#[cfg(feature = "sqrt-table")]
impl<F: FieldExt> SqrtHasher<F> {
    /// Computes the perfect hash of `x` used to index `SqrtTables::inv`.
    fn hash(&self, x: &F) -> usize {
        // Deliberately the simplest constant-time perfect hash that could possibly
        // work: the 32 low-order bits are unique within the 2^S order subgroup, the
        // xor acts as a "salt" that injectively randomizes the value, and the result
        // is reduced modulo `hash_mod`. The table is small, so nothing more
        // sophisticated is needed.
        let salted: u32 = x.get_lower_32() ^ self.hash_xor;
        (salted as usize) % self.hash_mod
    }
}
/// Tables used for square root computation.
///
/// All tables are filled in by `SqrtTables::new`.
#[cfg(feature = "sqrt-table")]
#[cfg_attr(docsrs, doc(cfg(feature = "sqrt-table")))]
#[derive(Debug)]
pub struct SqrtTables<F: FieldExt> {
/// Perfect-hash parameters used to index `inv`.
hasher: SqrtHasher<F>,
/// Lookup table of length `hash_mod`. For the `j`-th of the 256 powers generated
/// for the fourth table, the slot at its hash holds `(256 - j) & 0xFF`; every
/// other slot holds 1.
inv: Vec<u8>,
/// `g0[j] == ROOT_OF_UNITY^j` for `j` in `0..256`.
g0: Box<[F; 256]>,
/// `g1[j] == ROOT_OF_UNITY^(256 * j)` for `j` in `0..256`.
g1: Box<[F; 256]>,
/// `g2[j] == ROOT_OF_UNITY^(256^2 * j)` for `j` in `0..256`.
g2: Box<[F; 256]>,
/// `g3[j] == ROOT_OF_UNITY^(256^3 * j)`, truncated to `j` in `0..129`.
g3: Box<[F; 129]>,
}
#[cfg(feature = "sqrt-table")]
impl<F: FieldExt> SqrtTables<F> {
/// Build tables given parameters for the perfect hash.
///
/// `hash_xor` and `hash_mod` are stored in the internal hasher; `hash_mod` is also
/// the length of the `inv` lookup table.
///
/// # Panics
///
/// Panics if `hash_mod` is zero (the hash reduces modulo `hash_mod`), or if two of
/// the 256 entries of the fourth power table hash to the same `inv` slot, i.e. the
/// parameters do not describe a perfect hash.
pub fn new(hash_xor: u32, hash_mod: usize) -> Self {
use alloc::vec;
let hasher = SqrtHasher {
hash_xor,
hash_mod,
marker: PhantomData,
};
// Lazily yields four power tables; the scan state `gi` starts at ROOT_OF_UNITY.
let mut gtab = (0..4).scan(F::root_of_unity(), |gi, _| {
// gi == ROOT_OF_UNITY^(256^i)
// gtab_i[j] == gi^j for j in 0..256
let gtab_i: Vec<F> = (0..256)
.scan(F::one(), |acc, _| {
let res = *acc;
*acc *= *gi;
Some(res)
})
.collect();
// gi^255 * gi == gi^256 is the base of the next table.
*gi = gtab_i[255] * *gi;
Some(gtab_i)
});
let gtab_0 = gtab.next().unwrap();
let gtab_1 = gtab.next().unwrap();
let gtab_2 = gtab.next().unwrap();
let mut gtab_3 = gtab.next().unwrap();
// The scan runs over `0..4`, so it is exhausted after four tables.
assert_eq!(gtab.next(), None);
// Now invert gtab[3].
let mut inv: Vec<u8> = vec![1; hash_mod];
for (j, gtab_3_j) in gtab_3.iter().enumerate() {
let hash = hasher.hash(gtab_3_j);
// 1 is the last value to be assigned, so this ensures there are no collisions.
assert!(inv[hash] == 1);
inv[hash] = ((256 - j) & 0xFF) as u8;
}
// All 256 entries were needed to fill `inv`; only indices 0..=128 are kept for
// lookups (`sqrt_common` indexes `g3` with `t_ >> 24` after asserting
// `t_ <= 0x80000000`).
gtab_3.truncate(129);
SqrtTables::<F> {
hasher,
inv,
g0: gtab_0.into_boxed_slice().try_into().unwrap(),
g1: gtab_1.into_boxed_slice().try_into().unwrap(),
g2: gtab_2.into_boxed_slice().try_into().unwrap(),
g3: gtab_3.into_boxed_slice().try_into().unwrap(),
}
}
/// Computes:
///
/// * (true, sqrt(num/div)), if num and div are nonzero and num/div is a square in the field;
/// * (true, 0), if num is zero;
/// * (false, 0), if num is nonzero and div is zero;
/// * (false, sqrt(ROOT_OF_UNITY * num/div)), if num and div are nonzero and num/div is a nonsquare in the field;
///
/// where ROOT_OF_UNITY is a generator of the order 2^S subgroup (and therefore a nonsquare).
///
/// The choice of root from sqrt is unspecified.
///
/// # Panics
///
/// Panics if `num` and `div` are both nonzero and the computed result `res` does not
/// satisfy exactly one of `res^2 * div == num` and
/// `res^2 * div == ROOT_OF_UNITY * num`, or if one of the range assertions in
/// `sqrt_common` fails.
pub fn sqrt_ratio(&self, num: &F, div: &F) -> (Choice, F) {
// Based on:
// * [Sarkar2020](https://eprint.iacr.org/2020/1407)
// * [BDLSY2012](https://cr.yp.to/papers.html#ed25519)
//
// We need to calculate uv and v, where v = u^((T-1)/2), u = num/div, and p-1 = T * 2^S.
// We can rewrite as follows:
//
// v = (num/div)^((T-1)/2)
// = num^((T-1)/2) * div^(p-1 - (T-1)/2) [Fermat's Little Theorem]
// = " * div^(T * 2^S - (T-1)/2)
// = " * div^((2^(S+1) - 1)*(T-1)/2 + 2^S)
// = (num * div^(2^(S+1) - 1))^((T-1)/2) * div^(2^S)
//
// Let w = (num * div^(2^(S+1) - 1))^((T-1)/2) * div^(2^S - 1).
// Then v = w * div, and uv = num * v / div = num * w.
//
// We calculate:
//
// s = div^(2^S - 1) using an addition chain
// t = div^(2^(S+1) - 1) = s^2 * div
// w = (num * t)^((T-1)/2) * s using another addition chain
//
// then u and uv as above. The addition chains are given in
// https://github.com/zcash/pasta/blob/master/addchain_sqrt.py .
// The overall cost of this part is similar to a single full-width exponentiation,
// regardless of S.
// sqr(x, i) == x^(2^i)
let sqr = |x: F, i: u32| (0..i).fold(x, |x, _| x.square());
// s = div^(2^S - 1)
// Starting from div^(2^1 - 1), each of the five steps doubles the number of one
// bits in the exponent, ending at div^(2^32 - 1); this hard-codes S = 32.
let s = (0..5).fold(*div, |d: F, i| sqr(d, 1 << i) * d);
// t == div^(2^(S+1) - 1)
let t = s.square() * div;
// w = (num * t)^((T-1)/2) * s
let w = (t * num).pow_by_t_minus1_over2() * s;
// v == u^((T-1)/2)
let v = w * div;
// uv = u * v
let uv = w * num;
let res = self.sqrt_common(&uv, &v);
// Classify the result by testing res^2 * div against num and ROOT_OF_UNITY * num.
let sqdiv = res.square() * div;
let is_square = (sqdiv - num).is_zero();
let is_nonsquare = (sqdiv - F::root_of_unity() * num).is_zero();
assert!(bool::from(
num.is_zero() | div.is_zero() | (is_square ^ is_nonsquare)
));
(is_square, res)
}
/// Same as sqrt_ratio(u, one()) but more efficient.
///
/// # Panics
///
/// Panics if `u` is nonzero and the computed result `res` does not satisfy exactly
/// one of `res^2 == u` and `res^2 == ROOT_OF_UNITY * u`, or if one of the range
/// assertions in `sqrt_common` fails.
pub fn sqrt_alt(&self, u: &F) -> (Choice, F) {
// v = u^((T-1)/2)
let v = u.pow_by_t_minus1_over2();
let uv = *u * v;
let res = self.sqrt_common(&uv, &v);
// Classify the result by testing res^2 against u and ROOT_OF_UNITY * u.
let sq = res.square();
let is_square = (sq - u).is_zero();
let is_nonsquare = (sq - F::root_of_unity() * u).is_zero();
assert!(bool::from(u.is_zero() | (is_square ^ is_nonsquare)));
(is_square, res)
}
/// Common part of sqrt_ratio and sqrt_alt: return their result given v = u^((T-1)/2) and uv = u * v.
///
/// The exponent `t_` is assembled one byte at a time from `inv` table lookups, then
/// `uv` is multiplied by the matching entries of `g0`..`g3`.
fn sqrt_common(&self, uv: &F, v: &F) -> F {
// sqr(x, i) == x^(2^i)
let sqr = |x: F, i: u32| (0..i).fold(x, |x, _| x.square());
// Looks up the `inv` table entry at the perfect hash of `x`; the entry is a `u8`,
// so the result is at most 0xFF.
let inv = |x: F| self.inv[self.hasher.hash(&x)] as usize;
// x3 = uv * v; x2, x1, x0 are x3 raised to 2^8, 2^16 and 2^24 respectively.
let x3 = *uv * v;
let x2 = sqr(x3, 8);
let x1 = sqr(x2, 8);
let x0 = sqr(x1, 8);
// i = 0, 1
let mut t_ = inv(x0); // = t >> 16
// 1 == x0 * ROOT_OF_UNITY^(t_ << 24)
assert!(t_ < 0x100);
let alpha = x1 * self.g2[t_];
// i = 2
t_ += inv(alpha) << 8; // = t >> 8
// 1 == x1 * ROOT_OF_UNITY^(t_ << 16)
assert!(t_ < 0x10000);
let alpha = x2 * self.g1[t_ & 0xFF] * self.g2[t_ >> 8];
// i = 3
t_ += inv(alpha) << 16; // = t
// 1 == x2 * ROOT_OF_UNITY^(t_ << 8)
assert!(t_ < 0x1000000);
let alpha = x3 * self.g0[t_ & 0xFF] * self.g1[(t_ >> 8) & 0xFF] * self.g2[t_ >> 16];
t_ += inv(alpha) << 24; // = t << 1
// 1 == x3 * ROOT_OF_UNITY^t_
// Halve the exponent, rounding up.
t_ = (t_ + 1) >> 1;
// Keeps the `g3` index `t_ >> 24` within its 129 entries.
assert!(t_ <= 0x80000000);
*uv * self.g0[t_ & 0xFF]
* self.g1[(t_ >> 8) & 0xFF]
* self.g2[(t_ >> 16) & 0xFF]
* self.g3[t_ >> 24]
}
}
/// Add with carry: computes `a + b + carry` in 128-bit arithmetic.
///
/// Returns `(low, high)`, where `low` is the low 64 bits of the sum and `high` is
/// the carry out (the upper 64 bits of the sum).
#[inline(always)]
pub(crate) const fn adc(a: u64, b: u64, carry: u64) -> (u64, u64) {
    // Three 64-bit values always fit in a u128, so this cannot overflow.
    let wide: u128 = a as u128 + b as u128 + carry as u128;
    let low = wide as u64;
    let high = (wide >> 64) as u64;
    (low, high)
}
/// Subtract with borrow: computes `a - (b + borrow_bit)` with wrapping 128-bit
/// arithmetic, where `borrow_bit` is the top bit of `borrow`.
///
/// Returns `(low, high)`, where `low` is the low 64 bits of the difference and
/// `high` is the new borrow: bits 64..128 of the wrapped difference, which are all
/// ones when the subtraction underflowed and zero otherwise.
#[inline(always)]
pub(crate) const fn sbb(a: u64, b: u64, borrow: u64) -> (u64, u64) {
    // Only the most significant bit of the incoming borrow is consumed.
    let borrow_bit = (borrow >> 63) as u128;
    let subtrahend = b as u128 + borrow_bit;
    let wide = (a as u128).wrapping_sub(subtrahend);
    let low = wide as u64;
    let high = (wide >> 64) as u64;
    (low, high)
}
/// Multiply-accumulate: computes `a + (b * c) + carry` in 128-bit arithmetic.
///
/// Returns `(low, high)`, where `low` is the low 64 bits of the result and `high`
/// is the carry out (the upper 64 bits of the result).
#[inline(always)]
pub(crate) const fn mac(a: u64, b: u64, c: u64, carry: u64) -> (u64, u64) {
    // (2^64 - 1)^2 + 2 * (2^64 - 1) == 2^128 - 1, so the total fits in a u128.
    let product = (b as u128) * (c as u128);
    let wide: u128 = a as u128 + product + carry as u128;
    let low = wide as u64;
    let high = (wide >> 64) as u64;
    (low, high)
}