Add Tonelli-Shanks sqrt for 1 mod 16 primes.

This commit is contained in:
Sean Bowe 2017-06-26 23:22:41 -06:00
parent bbbd397b80
commit 9aceb63e7e
2 changed files with 106 additions and 3 deletions

View File

@ -16,3 +16,4 @@ syn = "0.11"
quote = "0.3"
num-bigint = "0.1"
num-traits = "0.1"
num-integer = "0.1"

View File

@ -7,12 +7,14 @@ extern crate quote;
extern crate num_bigint;
extern crate num_traits;
extern crate num_integer;
use num_integer::Integer;
use num_traits::{Zero, One, ToPrimitive};
use num_bigint::BigUint;
use std::str::FromStr;
#[proc_macro_derive(PrimeField, attributes(PrimeFieldModulus))]
#[proc_macro_derive(PrimeField, attributes(PrimeFieldModulus, PrimeFieldGenerator))]
pub fn prime_field(
input: proc_macro::TokenStream
) -> proc_macro::TokenStream
@ -32,6 +34,11 @@ pub fn prime_field(
.expect("Please supply a PrimeFieldModulus attribute")
.parse().expect("PrimeFieldModulus should be a number");
// We may be provided with a generator of p - 1 order. It is required that this generator be quadratic
// nonresidue.
let generator: Option<BigUint> = fetch_attr("PrimeFieldGenerator", &ast.attrs)
.map(|i| i.parse().expect("PrimeFieldGenerator must be a number."));
// The arithmetic in this library only works if the modulus*2 is smaller than the backing
// representation. Compute the number of limbs we need.
let mut limbs = 1;
@ -47,7 +54,7 @@ pub fn prime_field(
let mut gen = quote::Tokens::new();
gen.append(prime_field_repr_impl(&repr_ident, limbs));
gen.append(prime_field_constants_and_sqrt(&ast.ident, &repr_ident, modulus, limbs));
gen.append(prime_field_constants_and_sqrt(&ast.ident, &repr_ident, modulus, limbs, generator));
gen.append(prime_field_impl(&ast.ident, &repr_ident, limbs));
// Return the generated impl
@ -277,11 +284,45 @@ fn biguint_num_bits(
bits
}
fn exp(
base: BigUint,
exp: &BigUint,
modulus: &BigUint
) -> BigUint
{
let mut ret = BigUint::one();
for i in exp.to_bytes_be()
.into_iter()
.flat_map(|x| (0..8).rev().map(move |i| (x >> i).is_odd()))
{
ret = (&ret * &ret) % modulus;
if i {
ret = (ret * &base) % modulus;
}
}
ret
}
#[test]
fn test_exp() {
assert_eq!(
exp(
BigUint::from_str("4398572349857239485729348572983472345").unwrap(),
&BigUint::from_str("5489673498567349856734895").unwrap(),
&BigUint::from_str("52435875175126190479447740508185965837690552500527637822603658699938581184513").unwrap()
),
BigUint::from_str("4371221214068404307866768905142520595925044802278091865033317963560480051536").unwrap()
);
}
fn prime_field_constants_and_sqrt(
name: &syn::Ident,
repr: &syn::Ident,
modulus: BigUint,
limbs: usize
limbs: usize,
generator: Option<BigUint>
) -> quote::Tokens
{
let modulus_num_bits = biguint_num_bits(modulus.clone());
@ -290,6 +331,14 @@ fn prime_field_constants_and_sqrt(
// Compute R = 2**(64 * limbs) mod m
let r = (BigUint::one() << (limbs * 64)) % &modulus;
// modulus - 1 = 2^s * t
let mut s = 0;
let mut t = &modulus - BigUint::from_str("1").unwrap();
while t.is_even() {
t = t >> 1;
s += 1;
}
let sqrt_impl =
if (&modulus % BigUint::from_str("4").unwrap()) == BigUint::from_str("3").unwrap() {
let mod_minus_3_over_4 = biguint_to_u64_vec((&modulus - BigUint::from_str("3").unwrap()) >> 2);
@ -318,6 +367,59 @@ fn prime_field_constants_and_sqrt(
}
}
}
} else if (&modulus % BigUint::from_str("16").unwrap()) == BigUint::from_str("1").unwrap() {
let mod_minus_1_over_2 = biguint_to_u64_vec((&modulus - BigUint::from_str("1").unwrap()) >> 1);
let generator = generator.expect("PrimeFieldGenerator attribute should be provided; should be a generator of order p - 1 and quadratic nonresidue.");
let root_of_unity = biguint_to_u64_vec((exp(generator.clone(), &t, &modulus) * &r) % &modulus);
let t_plus_1_over_2 = biguint_to_u64_vec((&t + BigUint::one()) >> 1);
let t = biguint_to_u64_vec(t.clone());
quote!{
impl ::ff::SqrtField for #name {
fn sqrt(&self) -> Option<Self> {
// Tonelli-Shank's algorithm for q mod 16 = 1
// https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5)
if self.is_zero() {
return Some(*self);
}
if self.pow(#mod_minus_1_over_2) != Self::one() {
None
} else {
let mut c = #name(#repr(#root_of_unity));
let mut r = self.pow(#t_plus_1_over_2);
let mut t = self.pow(#t);
let mut m = #s;
while t != Self::one() {
let mut i = 1;
{
let mut t2i = t;
t2i.square();
loop {
if t2i == Self::one() {
break;
}
t2i.square();
i += 1;
}
}
for _ in 0..(m - i - 1) {
c.square();
}
r.mul_assign(&c);
c.square();
t.mul_assign(&c);
m = i;
}
Some(r)
}
}
}
}
} else {
quote!{}
};