diff --git a/src/arithmetic.rs b/src/arithmetic.rs index 8984eb1a..a68710dc 100644 --- a/src/arithmetic.rs +++ b/src/arithmetic.rs @@ -3,12 +3,15 @@ use super::multicore; pub use ff::Field; -use group::{ff::BatchInvert, Group as _}; +use group::{ + ff::{BatchInvert, PrimeField}, + Group as _, +}; pub use pasta_curves::arithmetic::*; fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { - let coeffs: Vec<[u8; 32]> = coeffs.iter().map(|a| a.to_bytes()).collect(); + let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); let c = if bases.len() < 4 { 1 @@ -18,7 +21,7 @@ fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut (f64::from(bases.len() as u32)).ln().ceil() as usize }; - fn get_at(segment: usize, c: usize, bytes: &[u8; 32]) -> usize { + fn get_at(segment: usize, c: usize, bytes: &F::Repr) -> usize { let skip_bits = segment * c; let skip_bytes = skip_bits / 8; @@ -27,7 +30,7 @@ fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut } let mut v = [0; 8]; - for (v, o) in v.iter_mut().zip(bytes[skip_bytes..].iter()) { + for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) { *v = *o; } @@ -79,7 +82,7 @@ fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut let mut buckets: Vec> = vec![Bucket::None; (1 << c) - 1]; for (coeff, base) in coeffs.iter().zip(bases.iter()) { - let coeff = get_at(current_segment, c, coeff); + let coeff = get_at::(current_segment, c, coeff); if coeff != 0 { buckets[coeff - 1].add_assign(base); } @@ -100,7 +103,7 @@ fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut /// Performs a small multi-exponentiation operation. /// Uses the double-and-add algorithm with doublings shared across points. pub fn small_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { - let coeffs: Vec<[u8; 32]> = coeffs.iter().map(|a| a.to_bytes()).collect(); + let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); let mut acc = C::Curve::identity(); // for byte idx @@ -110,7 +113,7 @@ pub fn small_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::C acc = acc.double(); // for each coeff for coeff_idx in 0..coeffs.len() { - let byte = coeffs[coeff_idx][byte_idx]; + let byte = coeffs[coeff_idx].as_ref()[byte_idx]; if ((byte >> bit_idx) & 1) != 0 { acc += bases[coeff_idx]; } diff --git a/src/transcript.rs b/src/transcript.rs index 1869c835..0570dd59 100644 --- a/src/transcript.rs +++ b/src/transcript.rs @@ -2,6 +2,7 @@ //! transcripts. use blake2b_simd::{Params as Blake2bParams, State as Blake2bState}; +use group::ff::PrimeField; use std::convert::TryInto; use crate::arithmetic::{Coordinates, CurveAffine, FieldExt}; @@ -97,9 +98,9 @@ impl TranscriptRead> } fn read_scalar(&mut self) -> io::Result { - let mut data = [0u8; 32]; - self.reader.read_exact(&mut data)?; - let scalar: C::Scalar = Option::from(C::Scalar::from_bytes(&data)).ok_or_else(|| { + let mut data = ::Repr::default(); + self.reader.read_exact(data.as_mut())?; + let scalar: C::Scalar = Option::from(C::Scalar::from_repr(data)).ok_or_else(|| { io::Error::new( io::ErrorKind::Other, "invalid field element encoding in proof", @@ -129,15 +130,15 @@ impl Transcript> "cannot write points at infinity to the transcript", ) })?; - self.state.update(&coords.x().to_bytes()); - self.state.update(&coords.y().to_bytes()); + self.state.update(coords.x().to_repr().as_ref()); + self.state.update(coords.y().to_repr().as_ref()); Ok(()) } fn common_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { self.state.update(&[BLAKE2B_PREFIX_SCALAR]); - self.state.update(&scalar.to_bytes()); + self.state.update(scalar.to_repr().as_ref()); Ok(()) } @@ -181,8 +182,8 @@ impl TranscriptWrite> } fn write_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { self.common_scalar(scalar)?; - let data = scalar.to_bytes(); - self.writer.write_all(&data[..]) + let data = scalar.to_repr(); + self.writer.write_all(data.as_ref()) } } @@ -204,15 +205,15 @@ impl Transcript> "cannot write points at infinity to the transcript", ) })?; - self.state.update(&coords.x().to_bytes()); - self.state.update(&coords.y().to_bytes()); + self.state.update(coords.x().to_repr().as_ref()); + self.state.update(coords.y().to_repr().as_ref()); Ok(()) } fn common_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { self.state.update(&[BLAKE2B_PREFIX_SCALAR]); - self.state.update(&scalar.to_bytes()); + self.state.update(scalar.to_repr().as_ref()); Ok(()) } @@ -277,12 +278,18 @@ impl EncodedChallenge for Challenge255 { fn new(challenge_input: &[u8; 64]) -> Self { Challenge255( - C::Scalar::from_bytes_wide(challenge_input).to_bytes(), + C::Scalar::from_bytes_wide(challenge_input) + .to_repr() + .as_ref() + .try_into() + .expect("Scalar fits into 256 bits"), PhantomData, ) } fn get_scalar(&self) -> C::Scalar { - C::Scalar::from_bytes(&self.0).unwrap() + let mut repr = ::Repr::default(); + repr.as_mut().copy_from_slice(&self.0); + C::Scalar::from_repr(repr).unwrap() } }