From 806748fbc4029007763a2f75eea99279bbffc4d7 Mon Sep 17 00:00:00 2001 From: Daira Hopwood Date: Sun, 10 Jan 2021 15:22:24 +0000 Subject: [PATCH] Use addition chains for powering by (T-1)/2. Signed-off-by: Daira Hopwood --- src/arithmetic/fields.rs | 10 ++++++++-- src/pasta/fields/fp.rs | 33 +++++++++++++++++++++++++++++++++ src/pasta/fields/fq.rs | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 2 deletions(-) diff --git a/src/arithmetic/fields.rs b/src/arithmetic/fields.rs index 07fcb880..4ad8fd37 100644 --- a/src/arithmetic/fields.rs +++ b/src/arithmetic/fields.rs @@ -103,6 +103,12 @@ pub trait FieldExt: /// canonically. fn get_lower_32(&self) -> u32; + /// Raise this field element to the power T_MINUS1_OVER2. + /// Field implementations may override this to use an efficient addition chain. + fn pow_by_t_minus1_over2(&self) -> Self { + ff::Field::pow_vartime(&self, &Self::T_MINUS1_OVER2) + } + /// Performs a batch inversion using Montgomery's trick, returns the product /// of every inverse. Zero inputs are ignored. fn batch_invert(v: &mut [Self]) -> Self { @@ -244,8 +250,8 @@ impl SqrtTables { // t == div^(2^(S+1) - 1) let t = s.square() * div; - // TODO: replace this with an addition chain. - let w = ff::Field::pow_vartime(&(t * num), &F::T_MINUS1_OVER2) * s; + // w = (num * t)^((T-1)/2) * s + let w = (t * num).pow_by_t_minus1_over2() * s; // v == u^((T-1)/2) let v = w * div; diff --git a/src/pasta/fields/fp.rs b/src/pasta/fields/fp.rs index 0857de10..2f038f34 100644 --- a/src/pasta/fields/fp.rs +++ b/src/pasta/fields/fp.rs @@ -762,6 +762,39 @@ impl FieldExt for Fp { tmp.0[0] as u32 } + + fn pow_by_t_minus1_over2(&self) -> Self { + let sqr = |x: Fp, i: u32| (0..i).fold(x, |x, _| x.square()); + + let r10 = self.square(); + let r11 = r10 * self; + let r110 = r11.square(); + let r111 = r110 * self; + let r1001 = r111 * r10; + let r1101 = r111 * r110; + let ra = sqr(*self, 129) * self; + let rb = sqr(ra, 7) * r1001; + let rc = sqr(rb, 7) * r1101; + let rd = sqr(rc, 4) * r11; + let re = sqr(rd, 6) * r111; + let rf = sqr(re, 3) * r111; + let rg = sqr(rf, 10) * r1001; + let rh = sqr(rg, 5) * r1001; + let ri = sqr(rh, 4) * r1001; + let rj = sqr(ri, 3) * r111; + let rk = sqr(rj, 4) * r1001; + let rl = sqr(rk, 5) * r11; + let rm = sqr(rl, 4) * r111; + let rn = sqr(rm, 4) * r11; + let ro = sqr(rn, 6) * r1001; + let rp = sqr(ro, 5) * r1101; + let rq = sqr(rp, 4) * r11; + let rr = sqr(rq, 7) * r111; + let rs = sqr(rr, 3) * r11; + let rt = rs.square(); + //assert!(rt == ff::Field::pow_vartime(&self, &Fp::T_MINUS1_OVER2)); + rt + } } #[cfg(test)] diff --git a/src/pasta/fields/fq.rs b/src/pasta/fields/fq.rs index 9d80a81c..df228881 100644 --- a/src/pasta/fields/fq.rs +++ b/src/pasta/fields/fq.rs @@ -762,6 +762,39 @@ impl FieldExt for Fq { tmp.0[0] as u32 } + + fn pow_by_t_minus1_over2(&self) -> Self { + let sqr = |x: Fq, i: u32| (0..i).fold(x, |x, _| x.square()); + + let s10 = self.square(); + let s11 = s10 * self; + let s111 = s11.square() * self; + let s1001 = s111 * s10; + let s1011 = s1001 * s10; + let s1101 = s1011 * s10; + let sa = sqr(*self, 129) * self; + let sb = sqr(sa, 7) * s1001; + let sc = sqr(sb, 7) * s1101; + let sd = sqr(sc, 4) * s11; + let se = sqr(sd, 6) * s111; + let sf = sqr(se, 3) * s111; + let sg = sqr(sf, 10) * s1001; + let sh = sqr(sg, 4) * s1001; + let si = sqr(sh, 5) * s1001; + let sj = sqr(si, 5) * s1001; + let sk = sqr(sj, 3) * s1001; + let sl = sqr(sk, 4) * s1011; + let sm = sqr(sl, 4) * s1011; + let sn = sqr(sm, 5) * s11; + let so = sqr(sn, 4) * self; + let sp = sqr(so, 5) * s11; + let sq = sqr(sp, 4) * s111; + let sr = sqr(sq, 5) * s1011; + let ss = sqr(sr, 3) * self; + let st = sqr(ss, 4); + //assert!(st == ff::Field::pow_vartime(&self, &Fq::T_MINUS1_OVER2)); + st + } } #[cfg(test)]