From eaa0de2964df4769b9ff3160ab66bec54793a9f1 Mon Sep 17 00:00:00 2001 From: Sean Bowe Date: Wed, 29 Jun 2016 22:00:22 -0600 Subject: [PATCH] Arithmetic in Fp --- src/fp.rs | 273 +++++++++++++++++++++++++++++++++++++++++++++++--- src/params.rs | 2 + 2 files changed, 261 insertions(+), 14 deletions(-) diff --git a/src/fp.rs b/src/fp.rs index e1eee06..99a0262 100644 --- a/src/fp.rs +++ b/src/fp.rs @@ -1,12 +1,15 @@ use rand::Rng; use num::{BigUint, Num}; use std::ops::{Mul,Add,Sub,Neg}; +use std::cmp::{PartialEq, Eq}; use std::convert::From; +use std::fmt; use std::marker::PhantomData; pub trait PrimeFieldParams { fn modulus() -> BigUint; fn bits() -> usize; + fn name() -> &'static str; } pub struct Fp { @@ -14,20 +17,112 @@ pub struct Fp { _marker: PhantomData

} -impl Fp

{ - pub fn zero() -> Self { unimplemented!() } - pub fn one() -> Self { unimplemented!() } - pub fn random(rng: &mut R) -> Self { unimplemented!() } +impl fmt::Debug for Fp

{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}({})", P::name(), self.value) + } +} - pub fn is_zero(&self) -> bool { unimplemented!() } - pub fn inverse(&self) -> Self { unimplemented!() } - pub fn squared(&self) -> Self { unimplemented!() } - pub fn pow(&self, exp: &Fp) -> Self { unimplemented!() } - pub fn test_bit(&self, bit: usize) -> bool { unimplemented!() } +impl Fp

{ + pub fn zero() -> Self { + use num::Zero; + + Fp { + value: BigUint::zero(), + _marker: PhantomData + } + } + pub fn one() -> Self { + use num::One; + + Fp { + value: BigUint::one(), + _marker: PhantomData + } + } + pub fn random(rng: &mut R) -> Self { + use num::num_bigint::RandBigInt; + use num::Zero; + + Fp { + value: rng.gen_biguint_range(&BigUint::zero(), &P::modulus()), + _marker: PhantomData + } + } + + pub fn is_zero(&self) -> bool { + use num::Zero; + + self.value == BigUint::zero() + } + + pub fn inverse(&self) -> Self { + if self.is_zero() { + // TODO: this should likely bleed through the abstraction layers + panic!("cannot get the multiplicative inverse of zero") + } else { + let mut res = Self::one(); + + let mut found_one = false; + + let exp = Self::zero() - Self::one() - Self::one(); + + for i in (0..P::bits()).rev() { + if found_one { + res = res.squared(); + } + + if exp.test_bit(i) { + found_one = true; + res = self * &res; + } + } + + res + } + } + pub fn squared(&self) -> Self { + self * self + } + pub fn pow(&self, exp: &Fp) -> Self { + let mut res = Self::one(); + + let mut found_one = false; + + for i in (0..P2::bits()).rev() { + if found_one { + res = res.squared(); + } + + if exp.test_bit(i) { + found_one = true; + res = res * self; + } + } + + res + } + pub fn test_bit(&self, bit: usize) -> bool { + // TODO: This is a naive approach. + use num::{One, Zero}; + + let mut b = BigUint::one(); + let two = &b + &b; + for _ in 0..bit { + b = &b + &b; + } + + (&self.value / b) % two != BigUint::zero() + } } impl<'a, P: PrimeFieldParams> From<&'a str> for Fp

{ - fn from(s: &'a str) -> Self { unimplemented!() } + fn from(s: &'a str) -> Self { + Fp { + value: BigUint::from_str_radix(s, 10).unwrap() % P::modulus(), + _marker: PhantomData + } + } } impl Clone for Fp

{ @@ -37,25 +132,66 @@ impl Clone for Fp

{ impl<'a, 'b, P: PrimeFieldParams> Add<&'b Fp

> for &'a Fp

{ type Output = Fp

; - fn add(self, other: &Fp

) -> Fp

{ unimplemented!() } + fn add(self, other: &Fp

) -> Fp

{ + let tmp = &self.value + &other.value; + if tmp >= P::modulus() { + Fp { + value: tmp - P::modulus(), + _marker: PhantomData + } + } else { + Fp { + value: tmp, + _marker: PhantomData + } + } + } } impl<'a, 'b, P: PrimeFieldParams> Sub<&'b Fp

> for &'a Fp

{ type Output = Fp

; - fn sub(self, other: &Fp

) -> Fp

{ unimplemented!() } + fn sub(self, other: &Fp

) -> Fp

{ + if other.value > self.value { + Fp { + value: (&self.value + P::modulus()) - &other.value, + _marker: PhantomData + } + } else { + Fp { + value: &self.value - &other.value, + _marker: PhantomData + } + } + } } impl<'a, 'b, P: PrimeFieldParams> Mul<&'b Fp

> for &'a Fp

{ type Output = Fp

; - fn mul(self, other: &Fp

) -> Fp

{ unimplemented!() } + fn mul(self, other: &Fp

) -> Fp

{ + Fp { + value: (&self.value * &other.value) % &P::modulus(), + _marker: PhantomData + } + } } impl<'a, P: PrimeFieldParams> Neg for &'a Fp

{ type Output = Fp

; - fn neg(self) -> Fp

{ unimplemented!() } + fn neg(self) -> Fp

{ + use num::Zero; + + Fp { + value: if self.value.is_zero() { + self.value.clone() + } else { + P::modulus() - &self.value + }, + _marker: PhantomData + } + } } impl Neg for Fp

{ @@ -66,4 +202,113 @@ impl Neg for Fp

{ } } +impl PartialEq for Fp

{ + fn eq(&self, other: &Self) -> bool { + self.value == other.value + } +} + +impl Eq for Fp

{} + forward_all_binop_to_ref_ref!(impl(P: PrimeFieldParams) Mul for Fp

, mul); + +#[cfg(test)] +mod large_field_tests { + use super::*; + use num::{BigUint, Num}; + + struct Small; + + impl PrimeFieldParams for Small { + fn modulus() -> BigUint { + BigUint::from_str_radix("21888242871839275222246405745257275088696311157297823662689037894645226208583", 10).unwrap() + } + + fn bits() -> usize { 254 } + fn name() -> &'static str { "Small" } + } + + type Ft = Fp; + + #[test] + fn bit_testing() { + let a = Ft::from("13"); + assert!(a.test_bit(0) == true); + assert!(a.test_bit(1) == false); + assert!(a.test_bit(2) == true); + assert!(a.test_bit(3) == true); + + let expected: Vec = [1,1,0,1,1,0,0,0,0,1,0,0,1,1,1,0,0,0,0,0,1,1,0,0,1,0,0,1,1] + .iter().map(|a| *a == 1).rev().collect(); + + let a = Ft::from("453624211"); + + for (i, b) in expected.into_iter().enumerate() { + assert!(a.test_bit(i) == b); + } + + let expected: Vec = [1,1,1,1,0,1,0,1,1,0,1,0,0,0,1,1,1,0,1,1,1,1,0,0,0,0,1,1,0,1,1,0,1,0,0,1,1,0,1,1,0,0,1,1,0,0,1,1,1,1,0,1,1,1,1,0,1,0,1,1,1,1,1,0,1,0,1,0,0,0,1,1,1,0,1,1,1,1,0,0,0,1,1,1,0,0,1,0,1,0,0,0,0,1,1,0,0,1,1,1,1,0,1,1,1,1,0,0,0,1,1,0,0,1,1,1,0,1,1,0,0,0,0,1,1,1,1,1,1,1,0,0,0,0,1,0,1,0,1,0,0,0,1,1,0,1,1,0,0,1,0,1,1,1,1,1,0,0,0,1,0,1,0,0,0,1,0,1,0,0,1,0,1,1,0,1,0,1,1,0,1,0,0,0,1,0,0,1,0,1,1,1,0,1,1,0,0,0,0,0,1,1,0,0,0,1,1,1,1,1,1,0,0,0,0,1,0,1,1,1,0,1,1,0,1,1,0,0,0,0,1,1,1,1,1,0,0,1,1,1,1,1,1,0,1,0,0,1,0,1,1,0,0] + .iter().map(|a| *a == 1).rev().collect(); + let a = Ft::from("13888242871869275222244405745257275088696211157297823662689037894645226208556"); + + for (i, b) in expected.into_iter().enumerate() { + assert!(a.test_bit(i) == b); + } + } +} + +#[cfg(test)] +mod small_field_tests { + use super::*; + use num::{BigUint, Num}; + + struct Small; + + impl PrimeFieldParams for Small { + fn modulus() -> BigUint { + BigUint::from_str_radix("13", 10).unwrap() + } + + fn bits() -> usize { 6 } + fn name() -> &'static str { "Small" } + } + + type Ft = Fp; + + #[test] + fn field_ops() { + fn test_field_operation Ft>(a: u64, b: u64, f: C, expected: u64) { + let af = Ft::from(format!("{}", a).as_ref()); + let bf = Ft::from(format!("{}", b).as_ref()); + let expectedf = Ft::from(format!("{}", expected).as_ref()); + + let res = f(&af, &bf); + + if res != expectedf { + panic!("res={:?} != expectedf={:?} (a={}, b={}, expected={})", res, expectedf, a, b, expected); + } + } + + const MODULO: u64 = 13; + + for a in 0..13u64 { + for b in 0..13u64 { + test_field_operation(a, b, |a,b| {a * b}, (a*b)%MODULO); + test_field_operation(a, b, |a,b| {a + b}, (a+b)%MODULO); + test_field_operation(a, b, |a,b| {a - b}, { + let mut tmp = (a as i64) - (b as i64); + if tmp < 0 { + tmp += MODULO as i64; + } + + tmp as u64 + }); + test_field_operation(a, b, |a,b| {a.pow(b)}, (a.pow(b as u32))%MODULO); + } + test_field_operation(a, 0, |a,_| {-a}, if a == 0 { 0 } else { MODULO - a }); + if a > 0 { + test_field_operation(a, 0, |a,_| {&a.inverse() * a}, 1); + } + } + } +} diff --git a/src/params.rs b/src/params.rs index c3ae735..cec7007 100644 --- a/src/params.rs +++ b/src/params.rs @@ -8,6 +8,7 @@ impl PrimeFieldParams for FrParams { BigUint::from_str_radix("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10).unwrap() } fn bits() -> usize { 254 } + fn name() -> &'static str { "Fr" } } pub struct FqParams; @@ -17,4 +18,5 @@ impl PrimeFieldParams for FqParams { BigUint::from_str_radix("21888242871839275222246405745257275088696311157297823662689037894645226208583", 10).unwrap() } fn bits() -> usize { 254 } + fn name() -> &'static str { "Fq" } }