ff: Rework BitIterator to work with both u8 and u64 limb sizes

This enables BitIterator to be used with both the byte encoding and limb
representation of scalars.
This commit is contained in:
Jack Grigg 2020-03-28 12:02:32 +13:00
parent fd79de5408
commit 232f0a50b8
11 changed files with 80 additions and 36 deletions

View File

@ -313,12 +313,12 @@ pub fn field_into_allocated_bits_le<E: ScalarEngine, CS: ConstraintSystem<E>, F:
// Deconstruct in big-endian bit order
let values = match value {
Some(ref value) => {
let mut field_char = BitIterator::new(F::char());
let mut field_char = BitIterator::<u64, _>::new(F::char());
let mut tmp = Vec::with_capacity(F::NUM_BITS as usize);
let mut found_one = false;
for b in BitIterator::new(value.into_repr()) {
for b in BitIterator::<u64, _>::new(value.into_repr()) {
// Skip leading bits
found_one |= field_char.next().unwrap();
if !found_one {

View File

@ -103,7 +103,9 @@ impl<E: ScalarEngine> AllocatedNum<E> {
// We want to ensure that the bit representation of a is
// less than or equal to r - 1.
let mut a = self.value.map(|e| BitIterator::new(e.into_repr()));
let mut a = self
.value
.map(|e| BitIterator::<u64, _>::new(e.into_repr()));
let mut b = E::Fr::char();
b.sub_noborrow(&1.into());
@ -115,7 +117,7 @@ impl<E: ScalarEngine> AllocatedNum<E> {
let mut found_one = false;
let mut i = 0;
for b in BitIterator::new(b) {
for b in BitIterator::<u64, _>::new(b) {
let a_bit = a.as_mut().map(|e| e.next().unwrap());
// Skip over unset bits at the beginning
@ -558,7 +560,7 @@ mod test {
assert!(cs.is_satisfied());
for (b, a) in BitIterator::new(r.into_repr())
for (b, a) in BitIterator::<u64, _>::new(r.into_repr())
.skip(1)
.zip(bits.iter().rev())
{

View File

@ -13,6 +13,7 @@ extern crate std;
pub use ff_derive::*;
use core::fmt;
use core::marker::PhantomData;
use core::ops::{Add, AddAssign, BitAnd, Mul, MulAssign, Neg, Shr, Sub, SubAssign};
use rand_core::RngCore;
#[cfg(feature = "std")]
@ -338,20 +339,25 @@ pub trait ScalarEngine: Sized + 'static + Clone {
}
#[derive(Debug)]
pub struct BitIterator<E> {
pub struct BitIterator<T, E: AsRef<[T]>> {
t: E,
n: usize,
_limb: PhantomData<T>,
}
impl<E: AsRef<[u64]>> BitIterator<E> {
impl<E: AsRef<[u64]>> BitIterator<u64, E> {
pub fn new(t: E) -> Self {
let n = t.as_ref().len() * 64;
BitIterator { t, n }
BitIterator {
t,
n,
_limb: PhantomData::default(),
}
}
}
impl<E: AsRef<[u64]>> Iterator for BitIterator<E> {
impl<E: AsRef<[u64]>> Iterator for BitIterator<u64, E> {
type Item = bool;
fn next(&mut self) -> Option<bool> {
@ -367,9 +373,37 @@ impl<E: AsRef<[u64]>> Iterator for BitIterator<E> {
}
}
impl<E: AsRef<[u8]>> BitIterator<u8, E> {
pub fn new(t: E) -> Self {
let n = t.as_ref().len() * 8;
BitIterator {
t,
n,
_limb: PhantomData::default(),
}
}
}
impl<E: AsRef<[u8]>> Iterator for BitIterator<u8, E> {
type Item = bool;
fn next(&mut self) -> Option<bool> {
if self.n == 0 {
None
} else {
self.n -= 1;
let part = self.n / 8;
let bit = self.n - (8 * part);
Some(self.t.as_ref()[part] & (1 << bit) > 0)
}
}
}
#[test]
fn test_bit_iterator() {
let mut a = BitIterator::new([0xa953_d79b_83f6_ab59, 0x6dea_2059_e200_bd39]);
let mut a = BitIterator::<u64, _>::new([0xa953_d79b_83f6_ab59, 0x6dea_2059_e200_bd39]);
let expected = "01101101111010100010000001011001111000100000000010111101001110011010100101010011110101111001101110000011111101101010101101011001";
for e in expected.chars() {
@ -380,7 +414,7 @@ fn test_bit_iterator() {
let expected = "1010010101111110101010000101101011101000011101110101001000011001100100100011011010001011011011010001011011101100110100111011010010110001000011110100110001100110011101101000101100011100100100100100001010011101010111110011101011000011101000111011011101011001";
let mut a = BitIterator::new([
let mut a = BitIterator::<u64, _>::new([
0x429d_5f3a_c3a3_b759,
0xb10f_4c66_768b_1c92,
0x9236_8b6d_16ec_d3b4,

View File

@ -81,7 +81,18 @@ macro_rules! curve_impl {
}
impl $affine {
fn mul_bits<S: AsRef<[u64]>>(&self, bits: BitIterator<S>) -> $projective {
fn mul_bits_u64<S: AsRef<[u64]>>(&self, bits: BitIterator<u64, S>) -> $projective {
let mut res = $projective::zero();
for i in bits {
res.double();
if i {
res.add_assign(self)
}
}
res
}
fn mul_bits_u8<S: AsRef<[u8]>>(&self, bits: BitIterator<u8, S>) -> $projective {
let mut res = $projective::zero();
for i in bits {
res.double();
@ -172,8 +183,8 @@ macro_rules! curve_impl {
}
fn mul<S: Into<<Self::Scalar as PrimeField>::Repr>>(&self, by: S) -> $projective {
let bits = BitIterator::new(by.into());
self.mul_bits(bits)
let bits = BitIterator::<u64, _>::new(by.into());
self.mul_bits_u64(bits)
}
fn into_projective(&self) -> $projective {
@ -655,7 +666,7 @@ macro_rules! curve_impl {
let mut found_one = false;
for i in BitIterator::new(other.into()) {
for i in BitIterator::<u64, _>::new(other.into()) {
if found_one {
res.double();
} else {
@ -992,8 +1003,8 @@ pub mod g1 {
impl G1Affine {
fn scale_by_cofactor(&self) -> G1 {
// G1 cofactor = (x - 1)^2 / 3 = 76329603384216526031706109802092473003
let cofactor = BitIterator::new([0x8c00aaab0000aaab, 0x396c8c005555e156]);
self.mul_bits(cofactor)
let cofactor = BitIterator::<u64, _>::new([0x8c00aaab0000aaab, 0x396c8c005555e156]);
self.mul_bits_u64(cofactor)
}
fn get_generator() -> Self {
@ -1714,7 +1725,7 @@ pub mod g2 {
fn scale_by_cofactor(&self) -> G2 {
// G2 cofactor = (x^8 - 4 x^7 + 5 x^6) - (4 x^4 + 6 x^3 - 4 x^2 - 4 x + 13) // 9
// 0x5d543a95414e7f1091d50792876a202cd91de4547085abaa68a205b2e5a7ddfa628f1cb4d9e82ef21537e293a6691ae1616ec6e786f0c70cf1c38e31c7238e5
let cofactor = BitIterator::new([
let cofactor = BitIterator::<u64, _>::new([
0xcf1c38e31c7238e5,
0x1616ec6e786f0c70,
0x21537e293a6691ae,
@ -1724,7 +1735,7 @@ pub mod g2 {
0x91d50792876a202,
0x5d543a95414e7f1,
]);
self.mul_bits(cofactor)
self.mul_bits_u64(cofactor)
}
fn perform_pairing(&self, other: &G1Affine) -> Fq12 {

View File

@ -82,7 +82,7 @@ impl Engine for Bls12 {
let mut f = Fq12::one();
let mut found_one = false;
for i in BitIterator::new(&[BLS_X >> 1]) {
for i in BitIterator::<u64, _>::new(&[BLS_X >> 1]) {
if !found_one {
found_one = i;
continue;
@ -324,7 +324,7 @@ impl G2Prepared {
let mut r: G2 = q.into();
let mut found_one = false;
for i in BitIterator::new([BLS_X >> 1]) {
for i in BitIterator::<u64, _>::new([BLS_X >> 1]) {
if !found_one {
found_one = i;
continue;

View File

@ -468,7 +468,7 @@ impl<E: JubjubEngine, Subgroup> Point<E, Subgroup> {
let mut res = Self::zero();
for b in BitIterator::new(scalar.into()) {
for b in BitIterator::<u64, _>::new(scalar.into()) {
res = res.double(params);
if b {

View File

@ -1,4 +1,3 @@
use byteorder::{ByteOrder, LittleEndian};
use ff::{
adc, mac_with_carry, sbb, BitIterator, Field, PowVartime, PrimeField, PrimeFieldDecodingError,
PrimeFieldRepr, SqrtField,
@ -721,7 +720,7 @@ impl Fs {
self.reduce();
}
fn mul_bits<S: AsRef<[u64]>>(&self, bits: BitIterator<S>) -> Self {
fn mul_bits<S: AsRef<[u8]>>(&self, bits: BitIterator<u8, S>) -> Self {
let mut res = Self::zero();
for bit in bits {
res = res.double();
@ -741,9 +740,7 @@ impl ToUniform for Fs {
/// Random Oracle output.
fn to_uniform(digest: &[u8]) -> Self {
assert_eq!(digest.len(), 64);
let mut repr: [u64; 8] = [0; 8];
LittleEndian::read_u64_into(digest, &mut repr);
Self::one().mul_bits(BitIterator::new(repr))
Self::one().mul_bits(BitIterator::<u8, _>::new(digest))
}
}

View File

@ -304,7 +304,7 @@ impl<E: JubjubEngine, Subgroup> Point<E, Subgroup> {
let mut res = Self::zero();
for b in BitIterator::new(scalar.into()) {
for b in BitIterator::<u64, _>::new(scalar.into()) {
res = res.double(params);
if b {

View File

@ -21,7 +21,7 @@ pub const SAPLING_COMMITMENT_TREE_DEPTH: usize = 32;
pub fn merkle_hash(depth: usize, lhs: &FrRepr, rhs: &FrRepr) -> FrRepr {
let lhs = {
let mut tmp = [false; 256];
for (a, b) in tmp.iter_mut().rev().zip(BitIterator::new(lhs)) {
for (a, b) in tmp.iter_mut().rev().zip(BitIterator::<u64, _>::new(lhs)) {
*a = b;
}
tmp
@ -29,7 +29,7 @@ pub fn merkle_hash(depth: usize, lhs: &FrRepr, rhs: &FrRepr) -> FrRepr {
let rhs = {
let mut tmp = [false; 256];
for (a, b) in tmp.iter_mut().rev().zip(BitIterator::new(rhs)) {
for (a, b) in tmp.iter_mut().rev().zip(BitIterator::<u64, _>::new(rhs)) {
*a = b;
}
tmp

View File

@ -769,7 +769,7 @@ mod test {
let q = p.mul(s, params);
let (x1, y1) = q.to_xy();
let mut s_bits = BitIterator::new(s.into_repr()).collect::<Vec<_>>();
let mut s_bits = BitIterator::<u64, _>::new(s.into_repr()).collect::<Vec<_>>();
s_bits.reverse();
s_bits.truncate(Fs::NUM_BITS as usize);
@ -822,7 +822,7 @@ mod test {
y: num_y0,
};
let mut s_bits = BitIterator::new(s.into_repr()).collect::<Vec<_>>();
let mut s_bits = BitIterator::<u64, _>::new(s.into_repr()).collect::<Vec<_>>();
s_bits.reverse();
s_bits.truncate(Fs::NUM_BITS as usize);

View File

@ -615,8 +615,8 @@ fn test_input_circuit_with_bls12_381() {
::std::mem::swap(&mut lhs, &mut rhs);
}
let mut lhs: Vec<bool> = BitIterator::new(lhs.into_repr()).collect();
let mut rhs: Vec<bool> = BitIterator::new(rhs.into_repr()).collect();
let mut lhs: Vec<bool> = BitIterator::<u64, _>::new(lhs.into_repr()).collect();
let mut rhs: Vec<bool> = BitIterator::<u64, _>::new(rhs.into_repr()).collect();
lhs.reverse();
rhs.reverse();
@ -799,8 +799,8 @@ fn test_input_circuit_with_bls12_381_external_test_vectors() {
::std::mem::swap(&mut lhs, &mut rhs);
}
let mut lhs: Vec<bool> = BitIterator::new(lhs.into_repr()).collect();
let mut rhs: Vec<bool> = BitIterator::new(rhs.into_repr()).collect();
let mut lhs: Vec<bool> = BitIterator::<u64, _>::new(lhs.into_repr()).collect();
let mut rhs: Vec<bool> = BitIterator::<u64, _>::new(rhs.into_repr()).collect();
lhs.reverse();
rhs.reverse();