Use `ff::PrimeField::{from_repr, to_repr}` instead of `FieldExt`

This commit is contained in:
Jack Grigg 2021-09-30 22:49:59 +01:00
parent 9693065a00
commit 0e6b0344f5
2 changed files with 30 additions and 20 deletions

View File

@ -3,12 +3,15 @@
use super::multicore; use super::multicore;
pub use ff::Field; pub use ff::Field;
use group::{ff::BatchInvert, Group as _}; use group::{
ff::{BatchInvert, PrimeField},
Group as _,
};
pub use pasta_curves::arithmetic::*; pub use pasta_curves::arithmetic::*;
fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
let coeffs: Vec<[u8; 32]> = coeffs.iter().map(|a| a.to_bytes()).collect(); let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();
let c = if bases.len() < 4 { let c = if bases.len() < 4 {
1 1
@ -18,7 +21,7 @@ fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut
(f64::from(bases.len() as u32)).ln().ceil() as usize (f64::from(bases.len() as u32)).ln().ceil() as usize
}; };
fn get_at(segment: usize, c: usize, bytes: &[u8; 32]) -> usize { fn get_at<F: PrimeField>(segment: usize, c: usize, bytes: &F::Repr) -> usize {
let skip_bits = segment * c; let skip_bits = segment * c;
let skip_bytes = skip_bits / 8; let skip_bytes = skip_bits / 8;
@ -27,7 +30,7 @@ fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut
} }
let mut v = [0; 8]; let mut v = [0; 8];
for (v, o) in v.iter_mut().zip(bytes[skip_bytes..].iter()) { for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) {
*v = *o; *v = *o;
} }
@ -79,7 +82,7 @@ fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut
let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; (1 << c) - 1]; let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; (1 << c) - 1];
for (coeff, base) in coeffs.iter().zip(bases.iter()) { for (coeff, base) in coeffs.iter().zip(bases.iter()) {
let coeff = get_at(current_segment, c, coeff); let coeff = get_at::<C::Scalar>(current_segment, c, coeff);
if coeff != 0 { if coeff != 0 {
buckets[coeff - 1].add_assign(base); buckets[coeff - 1].add_assign(base);
} }
@ -100,7 +103,7 @@ fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut
/// Performs a small multi-exponentiation operation. /// Performs a small multi-exponentiation operation.
/// Uses the double-and-add algorithm with doublings shared across points. /// Uses the double-and-add algorithm with doublings shared across points.
pub fn small_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { pub fn small_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
let coeffs: Vec<[u8; 32]> = coeffs.iter().map(|a| a.to_bytes()).collect(); let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();
let mut acc = C::Curve::identity(); let mut acc = C::Curve::identity();
// for byte idx // for byte idx
@ -110,7 +113,7 @@ pub fn small_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::C
acc = acc.double(); acc = acc.double();
// for each coeff // for each coeff
for coeff_idx in 0..coeffs.len() { for coeff_idx in 0..coeffs.len() {
let byte = coeffs[coeff_idx][byte_idx]; let byte = coeffs[coeff_idx].as_ref()[byte_idx];
if ((byte >> bit_idx) & 1) != 0 { if ((byte >> bit_idx) & 1) != 0 {
acc += bases[coeff_idx]; acc += bases[coeff_idx];
} }

View File

@ -2,6 +2,7 @@
//! transcripts. //! transcripts.
use blake2b_simd::{Params as Blake2bParams, State as Blake2bState}; use blake2b_simd::{Params as Blake2bParams, State as Blake2bState};
use group::ff::PrimeField;
use std::convert::TryInto; use std::convert::TryInto;
use crate::arithmetic::{Coordinates, CurveAffine, FieldExt}; use crate::arithmetic::{Coordinates, CurveAffine, FieldExt};
@ -97,9 +98,9 @@ impl<R: Read, C: CurveAffine> TranscriptRead<C, Challenge255<C>>
} }
fn read_scalar(&mut self) -> io::Result<C::Scalar> { fn read_scalar(&mut self) -> io::Result<C::Scalar> {
let mut data = [0u8; 32]; let mut data = <C::Scalar as PrimeField>::Repr::default();
self.reader.read_exact(&mut data)?; self.reader.read_exact(data.as_mut())?;
let scalar: C::Scalar = Option::from(C::Scalar::from_bytes(&data)).ok_or_else(|| { let scalar: C::Scalar = Option::from(C::Scalar::from_repr(data)).ok_or_else(|| {
io::Error::new( io::Error::new(
io::ErrorKind::Other, io::ErrorKind::Other,
"invalid field element encoding in proof", "invalid field element encoding in proof",
@ -129,15 +130,15 @@ impl<R: Read, C: CurveAffine> Transcript<C, Challenge255<C>>
"cannot write points at infinity to the transcript", "cannot write points at infinity to the transcript",
) )
})?; })?;
self.state.update(&coords.x().to_bytes()); self.state.update(coords.x().to_repr().as_ref());
self.state.update(&coords.y().to_bytes()); self.state.update(coords.y().to_repr().as_ref());
Ok(()) Ok(())
} }
fn common_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { fn common_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> {
self.state.update(&[BLAKE2B_PREFIX_SCALAR]); self.state.update(&[BLAKE2B_PREFIX_SCALAR]);
self.state.update(&scalar.to_bytes()); self.state.update(scalar.to_repr().as_ref());
Ok(()) Ok(())
} }
@ -181,8 +182,8 @@ impl<W: Write, C: CurveAffine> TranscriptWrite<C, Challenge255<C>>
} }
fn write_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { fn write_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> {
self.common_scalar(scalar)?; self.common_scalar(scalar)?;
let data = scalar.to_bytes(); let data = scalar.to_repr();
self.writer.write_all(&data[..]) self.writer.write_all(data.as_ref())
} }
} }
@ -204,15 +205,15 @@ impl<W: Write, C: CurveAffine> Transcript<C, Challenge255<C>>
"cannot write points at infinity to the transcript", "cannot write points at infinity to the transcript",
) )
})?; })?;
self.state.update(&coords.x().to_bytes()); self.state.update(coords.x().to_repr().as_ref());
self.state.update(&coords.y().to_bytes()); self.state.update(coords.y().to_repr().as_ref());
Ok(()) Ok(())
} }
fn common_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { fn common_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> {
self.state.update(&[BLAKE2B_PREFIX_SCALAR]); self.state.update(&[BLAKE2B_PREFIX_SCALAR]);
self.state.update(&scalar.to_bytes()); self.state.update(scalar.to_repr().as_ref());
Ok(()) Ok(())
} }
@ -277,12 +278,18 @@ impl<C: CurveAffine> EncodedChallenge<C> for Challenge255<C> {
fn new(challenge_input: &[u8; 64]) -> Self { fn new(challenge_input: &[u8; 64]) -> Self {
Challenge255( Challenge255(
C::Scalar::from_bytes_wide(challenge_input).to_bytes(), C::Scalar::from_bytes_wide(challenge_input)
.to_repr()
.as_ref()
.try_into()
.expect("Scalar fits into 256 bits"),
PhantomData, PhantomData,
) )
} }
fn get_scalar(&self) -> C::Scalar { fn get_scalar(&self) -> C::Scalar {
C::Scalar::from_bytes(&self.0).unwrap() let mut repr = <C::Scalar as PrimeField>::Repr::default();
repr.as_mut().copy_from_slice(&self.0);
C::Scalar::from_repr(repr).unwrap()
} }
} }