diff --git a/benches/fq_bench.rs b/benches/fq_bench.rs index 0b47324..ae95401 100644 --- a/benches/fq_bench.rs +++ b/benches/fq_bench.rs @@ -10,8 +10,16 @@ use jubjub::Fq; #[bench] fn bench_mul_assign(bencher: &mut Bencher) { let mut n = Fq::new([2, 2, 2, 2]); - let m = Fq::new([2, 2, 2, 2]); bencher.iter(move || { - n.mul_assign(&m); + let tmp = n; + n.mul_assign(&tmp); + }); +} + +#[bench] +fn bench_square_assign(bencher: &mut Bencher) { + let mut n = Fq::new([2, 2, 2, 2]); + bencher.iter(move || { + n.square_assign(); }); } diff --git a/src/fq.rs b/src/fq.rs index f9658ef..b86af28 100644 --- a/src/fq.rs +++ b/src/fq.rs @@ -261,11 +261,63 @@ impl Fq { res } + + /// Calculate a + (b * c) + carry, returning the least significant digit + /// and setting carry to the most significant digit. + #[inline(always)] + fn mac_with_carry(a: u64, b: u64, c: u64, carry: &mut u64) -> u64 { + let tmp = (u128::from(a)) + u128::from(b) * u128::from(c) + u128::from(*carry); + + *carry = (tmp >> 64) as u64; + + tmp as u64 + } + + /// Calculate a + b + carry, returning the sum and modifying the + /// carry value. + #[inline(always)] + fn adc(a: u64, b: u64, carry: &mut u64) -> u64 { + let tmp = u128::from(a) + u128::from(b) + u128::from(*carry); + + *carry = (tmp >> 64) as u64; + + tmp as u64 + } /// Squares this element. pub fn square_assign(&mut self) { - let tmp = *self; - self.mul_assign(&tmp); + let mut carry = 0; + let r1 = Fq::mac_with_carry(0, self.0[0], self.0[1], &mut carry); + let r2 = Fq::mac_with_carry(0, self.0[0], self.0[2], &mut carry); + let r3 = Fq::mac_with_carry(0, self.0[0], self.0[3], &mut carry); + let r4 = carry; + let mut carry = 0; + let r3 = Fq::mac_with_carry(r3, self.0[1], self.0[2], &mut carry); + let r4 = Fq::mac_with_carry(r4, self.0[1], self.0[3], &mut carry); + let r5 = carry; + let mut carry = 0; + let r5 = Fq::mac_with_carry(r5, self.0[2], self.0[3], &mut carry); + let r6 = carry; + + let r7 = r6 >> 63; + let r6 = (r6 << 1) | (r5 >> 63); + let r5 = (r5 << 1) | (r4 >> 63); + let r4 = (r4 << 1) | (r3 >> 63); + let r3 = (r3 << 1) | (r2 >> 63); + let r2 = (r2 << 1) | (r1 >> 63); + let r1 = r1 << 1; + + let mut carry = 0; + let r0 = Fq::mac_with_carry(0, self.0[0], self.0[0], &mut carry); + let r1 = Fq::adc(r1, 0, &mut carry); + let r2 = Fq::mac_with_carry(r2, self.0[1], self.0[1], &mut carry); + let r3 = Fq::adc(r3, 0, &mut carry); + let r4 = Fq::mac_with_carry(r4, self.0[2], self.0[2], &mut carry); + let r5 = Fq::adc(r5, 0, &mut carry); + let r6 = Fq::mac_with_carry(r6, self.0[3], self.0[3], &mut carry); + let r7 = Fq::adc(r7, 0, &mut carry); + + self.montgomery_reduce(r0, r1, r2, r3, r4, r5, r6, r7); } /// Exponentiates `self` by `by`, where `by` is a @@ -634,3 +686,15 @@ fn test_inversion() { tmp.add_assign(&R2); } } + +#[test] +fn test_square_assign_equals_mul_assign() { + let mut n1 = Fq([2, 2, 2 ,2]); + let mut n2 = Fq([2, 2, 2 ,2]); + for _ in 1..100 { + let tmp = n1; + n1.mul_assign(&tmp); + n2.square_assign(); + assert_eq!(n1, n2); + } +}