use std::{
    cmp, fmt,
    hash::{Hash, Hasher},
    marker::PhantomData,
    ops::{Add, Mul, MulAssign, Neg, Sub},
    sync::Arc,
};

use ff::WithSmallOrderMulGroup;
use group::ff::Field;

use super::{
    Basis, Coeff, EvaluationDomain, ExtendedLagrangeCoeff, LagrangeCoeff, Polynomial, Rotation,
};
use crate::multicore;

/// Returns `(chunk_size, num_chunks)` suitable for processing the given polynomial length
/// in the current parallelization environment.
fn get_chunk_params(poly_len: usize) -> (usize, usize) {
    // Check the level of parallelization we have available.
    let num_threads = multicore::current_num_threads();
    // We scale the number of chunks by a constant factor, to ensure that if not all
    // threads are available, we can achieve more uniform throughput and don't end up
    // waiting on a couple of threads to process the last chunks.
    let num_chunks = num_threads * 4;
    // Calculate the ideal chunk size for the desired throughput. We use ceiling
    // division to ensure the minimum chunk size is 1.
    // chunk_size = ceil(poly_len / num_chunks)
    let chunk_size = (poly_len + num_chunks - 1) / num_chunks;
    // Now re-calculate num_chunks from the actual chunk size.
    // num_chunks = ceil(poly_len / chunk_size)
    let num_chunks = (poly_len + chunk_size - 1) / chunk_size;

    (chunk_size, num_chunks)
}

/// A reference to a polynomial registered with an [`Evaluator`].
#[derive(Clone, Copy)]
pub(crate) struct AstLeaf<E, B: Basis> {
    index: usize,
    rotation: Rotation,
    _evaluator: PhantomData<(E, B)>,
}

impl<E, B: Basis> fmt::Debug for AstLeaf<E, B> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("AstLeaf")
            .field("index", &self.index)
            .field("rotation", &self.rotation)
            .finish()
    }
}

impl<E, B: Basis> PartialEq for AstLeaf<E, B> {
    fn eq(&self, rhs: &Self) -> bool {
        // We compare rotations by offset, which doesn't account for equivalent rotations.
        self.index.eq(&rhs.index) && self.rotation.0.eq(&rhs.rotation.0)
    }
}

impl<E, B: Basis> Eq for AstLeaf<E, B> {}

impl<E, B: Basis> Hash for AstLeaf<E, B> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.index.hash(state);
        self.rotation.0.hash(state);
    }
}

impl<E, B: Basis> AstLeaf<E, B> {
    /// Produces a new `AstLeaf` node corresponding to the underlying polynomial at a
    /// _new_ rotation. Existing rotations applied to this leaf node are ignored and the
    /// returned polynomial is not rotated _relative_ to the previous structure.
    pub(crate) fn with_rotation(&self, rotation: Rotation) -> Self {
        AstLeaf {
            index: self.index,
            rotation,
            _evaluator: PhantomData::default(),
        }
    }
}

/// An evaluation context for polynomial operations.
///
/// This context enables us to de-duplicate queries of circuit columns (and the rotations
/// they might require), by storing a list of all the underlying polynomials involved in
/// any query (which are almost certainly column polynomials). We use the context like so:
///
/// - We register each underlying polynomial with the evaluator, which returns a reference
///   to it as an [`AstLeaf`].
/// - The references are then used to build up an [`Ast`] that represents the overall
///   operations to be applied to the polynomials.
/// - Finally, we call [`Evaluator::evaluate`] passing in the [`Ast`].
pub(crate) struct Evaluator<E, F: Field, B: Basis> {
    polys: Vec<Polynomial<F, B>>,
    _context: E,
}
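// A minimal usage sketch of the workflow above (illustrative only: the domain parameters,
// the chosen basis, and the names `a`, `b`, `ast` are assumptions, not taken from any
// particular caller in this crate):
//
//     let domain = EvaluationDomain::<pallas::Base>::new(1, 3);
//     let mut evaluator = new_evaluator::<_, _, LagrangeCoeff>(|| {});
//     let a = evaluator.register_poly(domain.empty_lagrange());
//     let b = evaluator.register_poly(domain.empty_lagrange());
//     let ast = Ast::from(a) + Ast::from(b) * pallas::Base::from(2);
//     let poly = evaluator.evaluate(&ast, &domain);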
/// Constructs a new `Evaluator`.
///
/// The `context` parameter is used to provide type safety for evaluators. It ensures that
/// an evaluator will only be used to evaluate [`Ast`]s containing [`AstLeaf`]s obtained
/// from itself. It should be set to the empty closure `|| {}`, because anonymous closures
/// all have unique types.
pub(crate) fn new_evaluator<E: Fn() + Clone, F: Field, B: Basis>(context: E) -> Evaluator<E, F, B> {
    Evaluator {
        polys: vec![],
        _context: context,
    }
}

impl<E, F: Field, B: Basis> Evaluator<E, F, B> {
    /// Registers the given polynomial for use in this evaluation context.
    ///
    /// This API treats each registered polynomial as unique, even if the same polynomial
    /// is added multiple times.
    pub(crate) fn register_poly(&mut self, poly: Polynomial<F, B>) -> AstLeaf<E, B> {
        let index = self.polys.len();
        self.polys.push(poly);

        AstLeaf {
            index,
            rotation: Rotation::cur(),
            _evaluator: PhantomData::default(),
        }
    }

    /// Evaluates the given polynomial operation against this context.
    pub(crate) fn evaluate(
        &self,
        ast: &Ast<E, F, B>,
        domain: &EvaluationDomain<F>,
    ) -> Polynomial<F, B>
    where
        E: Copy + Send + Sync,
        F: WithSmallOrderMulGroup<3>,
        B: BasisOps,
    {
        // We're working in a single basis, so all polynomials are the same length.
        let poly_len = self.polys.first().unwrap().len();
        let (chunk_size, _num_chunks) = get_chunk_params(poly_len);

        struct AstContext<'a, F: Field, B: Basis> {
            domain: &'a EvaluationDomain<F>,
            poly_len: usize,
            chunk_size: usize,
            chunk_index: usize,
            polys: &'a [Polynomial<F, B>],
        }

        fn recurse<E, F: WithSmallOrderMulGroup<3>, B: BasisOps>(
            ast: &Ast<E, F, B>,
            ctx: &AstContext<'_, F, B>,
        ) -> Vec<F> {
            match ast {
                Ast::Poly(leaf) => B::get_chunk_of_rotated(
                    ctx.domain,
                    ctx.chunk_size,
                    ctx.chunk_index,
                    &ctx.polys[leaf.index],
                    leaf.rotation,
                ),
                Ast::Add(a, b) => {
                    let mut lhs = recurse(a, ctx);
                    let rhs = recurse(b, ctx);
                    for (lhs, rhs) in lhs.iter_mut().zip(rhs.iter()) {
                        *lhs += *rhs;
                    }
                    lhs
                }
                Ast::Mul(AstMul(a, b)) => {
                    let mut lhs = recurse(a, ctx);
                    let rhs = recurse(b, ctx);
                    for (lhs, rhs) in lhs.iter_mut().zip(rhs.iter()) {
                        *lhs *= *rhs;
                    }
                    lhs
                }
                Ast::Scale(a, scalar) => {
                    let mut lhs = recurse(a, ctx);
                    for lhs in lhs.iter_mut() {
                        *lhs *= scalar;
                    }
                    lhs
                }
                Ast::DistributePowers(terms, base) => terms.iter().fold(
                    B::constant_term(ctx.poly_len, ctx.chunk_size, ctx.chunk_index, F::ZERO),
                    |mut acc, term| {
                        let term = recurse(term, ctx);
                        for (acc, term) in acc.iter_mut().zip(term) {
                            *acc *= base;
                            *acc += term;
                        }
                        acc
                    },
                ),
                Ast::LinearTerm(scalar) => B::linear_term(
                    ctx.domain,
                    ctx.poly_len,
                    ctx.chunk_size,
                    ctx.chunk_index,
                    *scalar,
                ),
                Ast::ConstantTerm(scalar) => {
                    B::constant_term(ctx.poly_len, ctx.chunk_size, ctx.chunk_index, *scalar)
                }
            }
        }

        // Apply `ast` to each chunk in parallel, writing the result into an output
        // polynomial.
        let mut result = B::empty_poly(domain);
        multicore::scope(|scope| {
            for (chunk_index, out) in result.chunks_mut(chunk_size).enumerate() {
                scope.spawn(move |_| {
                    let ctx = AstContext {
                        domain,
                        poly_len,
                        chunk_size,
                        chunk_index,
                        polys: &self.polys,
                    };
                    out.copy_from_slice(&recurse(ast, &ctx));
                });
            }
        });

        result
    }
}

/// Struct representing the [`Ast::Mul`] case.
///
/// This struct exists to make the internals of this case private so that we don't
/// accidentally construct this case directly, because it can only be implemented for the
/// [`LagrangeCoeff`] and [`ExtendedLagrangeCoeff`] bases.
#[derive(Clone)]
pub(crate) struct AstMul<E, F: Field, B: Basis>(Arc<Ast<E, F, B>>, Arc<Ast<E, F, B>>);

impl<E, F: Field, B: Basis> fmt::Debug for AstMul<E, F, B> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("AstMul")
            .field(&self.0)
            .field(&self.1)
            .finish()
    }
}

/// A polynomial operation backed by an [`Evaluator`].
#[derive(Clone)]
pub(crate) enum Ast<E, F: Field, B: Basis> {
    Poly(AstLeaf<E, B>),
    Add(Arc<Ast<E, F, B>>, Arc<Ast<E, F, B>>),
    Mul(AstMul<E, F, B>),
    Scale(Arc<Ast<E, F, B>>, F),
    /// Represents a linear combination of a vector of nodes and the powers of a
    /// field element, where the nodes are ordered from highest to lowest degree
    /// terms.
    DistributePowers(Arc<Vec<Ast<E, F, B>>>, F),
    /// The degree-1 term of a polynomial.
    ///
    /// The field element is the coefficient of the term in the standard basis, not the
    /// evaluation basis.
    LinearTerm(F),
    /// The degree-0 term of a polynomial.
    ///
    /// The field element is the same in both the standard and evaluation bases.
    ConstantTerm(F),
}
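// As a worked example of the `DistributePowers` semantics, with three terms:
//
//     Ast::distribute_powers([t0, t1, t2], base)
//
// is evaluated by folding left-to-right, multiplying the accumulator by `base` before
// adding each term: ((0 * base + t0) * base + t1) * base + t2 = t0*base^2 + t1*base + t2,
// i.e. Horner's rule over terms ordered from highest to lowest degree.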
impl<E, F: Field, B: Basis> Ast<E, F, B> {
    pub fn distribute_powers<I: IntoIterator<Item = Self>>(i: I, base: F) -> Self {
        Ast::DistributePowers(Arc::new(i.into_iter().collect()), base)
    }
}

impl<E, F: Field, B: Basis> fmt::Debug for Ast<E, F, B> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Poly(leaf) => f.debug_tuple("Poly").field(leaf).finish(),
            Self::Add(lhs, rhs) => f.debug_tuple("Add").field(lhs).field(rhs).finish(),
            Self::Mul(x) => f.debug_tuple("Mul").field(x).finish(),
            Self::Scale(base, scalar) => f.debug_tuple("Scale").field(base).field(scalar).finish(),
            Self::DistributePowers(terms, base) => f
                .debug_tuple("DistributePowers")
                .field(terms)
                .field(base)
                .finish(),
            Self::LinearTerm(x) => f.debug_tuple("LinearTerm").field(x).finish(),
            Self::ConstantTerm(x) => f.debug_tuple("ConstantTerm").field(x).finish(),
        }
    }
}

impl<E, F: Field, B: Basis> From<AstLeaf<E, B>> for Ast<E, F, B> {
    fn from(leaf: AstLeaf<E, B>) -> Self {
        Ast::Poly(leaf)
    }
}

impl<E, F: Field, B: Basis> Ast<E, F, B> {
    pub(crate) fn one() -> Self {
        Self::ConstantTerm(F::ONE)
    }
}

impl<E, F: Field, B: Basis> Neg for Ast<E, F, B> {
    type Output = Ast<E, F, B>;

    fn neg(self) -> Self::Output {
        Ast::Scale(Arc::new(self), -F::ONE)
    }
}

impl<E: Clone, F: Field, B: Basis> Neg for &Ast<E, F, B> {
    type Output = Ast<E, F, B>;

    fn neg(self) -> Self::Output {
        -(self.clone())
    }
}

impl<E, F: Field, B: Basis> Add for Ast<E, F, B> {
    type Output = Ast<E, F, B>;

    fn add(self, other: Self) -> Self::Output {
        Ast::Add(Arc::new(self), Arc::new(other))
    }
}

impl<'a, E: Clone, F: Field, B: Basis> Add<&'a Ast<E, F, B>> for &'a Ast<E, F, B> {
    type Output = Ast<E, F, B>;

    fn add(self, other: &'a Ast<E, F, B>) -> Self::Output {
        self.clone() + other.clone()
    }
}

impl<E, F: Field, B: Basis> Add<AstLeaf<E, B>> for Ast<E, F, B> {
    type Output = Ast<E, F, B>;

    fn add(self, other: AstLeaf<E, B>) -> Self::Output {
        Ast::Add(Arc::new(self), Arc::new(other.into()))
    }
}

impl<E, F: Field, B: Basis> Sub for Ast<E, F, B> {
    type Output = Ast<E, F, B>;

    fn sub(self, other: Self) -> Self::Output {
        self + (-other)
    }
}

impl<'a, E: Clone, F: Field, B: Basis> Sub<&'a Ast<E, F, B>> for &'a Ast<E, F, B> {
    type Output = Ast<E, F, B>;

    fn sub(self, other: &'a Ast<E, F, B>) -> Self::Output {
        self + &(-other)
    }
}

impl<E, F: Field, B: Basis> Sub<AstLeaf<E, B>> for Ast<E, F, B> {
    type Output = Ast<E, F, B>;

    fn sub(self, other: AstLeaf<E, B>) -> Self::Output {
        self + (-Ast::from(other))
    }
}

impl<E, F: Field> Mul for Ast<E, F, LagrangeCoeff> {
    type Output = Ast<E, F, LagrangeCoeff>;

    fn mul(self, other: Self) -> Self::Output {
        Ast::Mul(AstMul(Arc::new(self), Arc::new(other)))
    }
}

impl<'a, E: Clone, F: Field> Mul<&'a Ast<E, F, LagrangeCoeff>> for &'a Ast<E, F, LagrangeCoeff> {
    type Output = Ast<E, F, LagrangeCoeff>;

    fn mul(self, other: &'a Ast<E, F, LagrangeCoeff>) -> Self::Output {
        self.clone() * other.clone()
    }
}

impl<E, F: Field> Mul<AstLeaf<E, LagrangeCoeff>> for Ast<E, F, LagrangeCoeff> {
    type Output = Ast<E, F, LagrangeCoeff>;

    fn mul(self, other: AstLeaf<E, LagrangeCoeff>) -> Self::Output {
        Ast::Mul(AstMul(Arc::new(self), Arc::new(other.into())))
    }
}

impl<E, F: Field> Mul for Ast<E, F, ExtendedLagrangeCoeff> {
    type Output = Ast<E, F, ExtendedLagrangeCoeff>;

    fn mul(self, other: Self) -> Self::Output {
        Ast::Mul(AstMul(Arc::new(self), Arc::new(other)))
    }
}

impl<'a, E: Clone, F: Field> Mul<&'a Ast<E, F, ExtendedLagrangeCoeff>>
    for &'a Ast<E, F, ExtendedLagrangeCoeff>
{
    type Output = Ast<E, F, ExtendedLagrangeCoeff>;

    fn mul(self, other: &'a Ast<E, F, ExtendedLagrangeCoeff>) -> Self::Output {
        self.clone() * other.clone()
    }
}

impl<E, F: Field> Mul<AstLeaf<E, ExtendedLagrangeCoeff>> for Ast<E, F, ExtendedLagrangeCoeff> {
    type Output = Ast<E, F, ExtendedLagrangeCoeff>;

    fn mul(self, other: AstLeaf<E, ExtendedLagrangeCoeff>) -> Self::Output {
        Ast::Mul(AstMul(Arc::new(self), Arc::new(other.into())))
    }
}

impl<E, F: Field, B: Basis> Mul<F> for Ast<E, F, B> {
    type Output = Ast<E, F, B>;

    fn mul(self, other: F) -> Self::Output {
        Ast::Scale(Arc::new(self), other)
    }
}

impl<E: Clone, F: Field, B: Basis> Mul<F> for &Ast<E, F, B> {
    type Output = Ast<E, F, B>;

    fn mul(self, other: F) -> Self::Output {
        Ast::Scale(Arc::new(self.clone()), other)
    }
}

impl<E: Clone, F: Field> MulAssign for Ast<E, F, ExtendedLagrangeCoeff> {
    fn mul_assign(&mut self, rhs: Self) {
        *self = self.clone().mul(rhs)
    }
}
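// Note on chunk handling: each helper below fills the index range starting at
// `chunk_size * chunk_index` and produces `min(chunk_size, poly_len - chunk_size * chunk_index)`
// values, so the final chunk may be shorter than `chunk_size`. For example, `poly_len = 10`
// with `chunk_size = 4` yields chunks of lengths 4, 4 and 2; `short_chunk_regression_test`
// below exercises this case.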
/// Operations which can be performed over a given basis.
pub(crate) trait BasisOps: Basis {
    fn empty_poly<F: WithSmallOrderMulGroup<3>>(
        domain: &EvaluationDomain<F>,
    ) -> Polynomial<F, Self>;
    fn constant_term<F: Field>(
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F>;
    fn linear_term<F: WithSmallOrderMulGroup<3>>(
        domain: &EvaluationDomain<F>,
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F>;
    fn get_chunk_of_rotated<F: WithSmallOrderMulGroup<3>>(
        domain: &EvaluationDomain<F>,
        chunk_size: usize,
        chunk_index: usize,
        poly: &Polynomial<F, Self>,
        rotation: Rotation,
    ) -> Vec<F>;
}

impl BasisOps for Coeff {
    fn empty_poly<F: WithSmallOrderMulGroup<3>>(
        domain: &EvaluationDomain<F>,
    ) -> Polynomial<F, Self> {
        domain.empty_coeff()
    }

    fn constant_term<F: Field>(
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F> {
        let mut chunk = vec![F::ZERO; cmp::min(chunk_size, poly_len - chunk_size * chunk_index)];
        if chunk_index == 0 {
            chunk[0] = scalar;
        }
        chunk
    }

    fn linear_term<F: WithSmallOrderMulGroup<3>>(
        _: &EvaluationDomain<F>,
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F> {
        let mut chunk = vec![F::ZERO; cmp::min(chunk_size, poly_len - chunk_size * chunk_index)];
        // If the chunk size is 1 (e.g. if we have a small k and many threads), then the
        // linear coefficient is the second chunk. Otherwise, the chunk size is greater
        // than one, and the linear coefficient is the second element of the first chunk.
        // Note that we check against the original chunk size, not the potentially-short
        // actual size of the current chunk, because we want to know whether the size of
        // the previous chunk was 1.
        if chunk_size == 1 && chunk_index == 1 {
            chunk[0] = scalar;
        } else if chunk_index == 0 {
            chunk[1] = scalar;
        }
        chunk
    }

    fn get_chunk_of_rotated<F: WithSmallOrderMulGroup<3>>(
        _: &EvaluationDomain<F>,
        _: usize,
        _: usize,
        _: &Polynomial<F, Self>,
        _: Rotation,
    ) -> Vec<F> {
        panic!("Can't rotate polynomials in the standard basis")
    }
}

impl BasisOps for LagrangeCoeff {
    fn empty_poly<F: WithSmallOrderMulGroup<3>>(
        domain: &EvaluationDomain<F>,
    ) -> Polynomial<F, Self> {
        domain.empty_lagrange()
    }

    fn constant_term<F: Field>(
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F> {
        vec![scalar; cmp::min(chunk_size, poly_len - chunk_size * chunk_index)]
    }

    fn linear_term<F: WithSmallOrderMulGroup<3>>(
        domain: &EvaluationDomain<F>,
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F> {
        // Take every power of omega within the chunk, and multiply by scalar.
        let omega = domain.get_omega();
        let start = chunk_size * chunk_index;
        (0..cmp::min(chunk_size, poly_len - start))
            .scan(omega.pow_vartime([start as u64]) * scalar, |acc, _| {
                let ret = *acc;
                *acc *= omega;
                Some(ret)
            })
            .collect()
    }

    fn get_chunk_of_rotated<F: WithSmallOrderMulGroup<3>>(
        _: &EvaluationDomain<F>,
        chunk_size: usize,
        chunk_index: usize,
        poly: &Polynomial<F, Self>,
        rotation: Rotation,
    ) -> Vec<F> {
        poly.get_chunk_of_rotated(rotation, chunk_size, chunk_index)
    }
}

impl BasisOps for ExtendedLagrangeCoeff {
    fn empty_poly<F: WithSmallOrderMulGroup<3>>(
        domain: &EvaluationDomain<F>,
    ) -> Polynomial<F, Self> {
        domain.empty_extended()
    }

    fn constant_term<F: Field>(
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F> {
        vec![scalar; cmp::min(chunk_size, poly_len - chunk_size * chunk_index)]
    }

    fn linear_term<F: WithSmallOrderMulGroup<3>>(
        domain: &EvaluationDomain<F>,
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F> {
        // Take every power of the extended omega within the chunk, and multiply by scalar.
        let omega = domain.get_extended_omega();
        let start = chunk_size * chunk_index;
        (0..cmp::min(chunk_size, poly_len - start))
            .scan(
                omega.pow_vartime([start as u64]) * F::ZETA * scalar,
                |acc, _| {
                    let ret = *acc;
                    *acc *= omega;
                    Some(ret)
                },
            )
            .collect()
    }

    fn get_chunk_of_rotated<F: WithSmallOrderMulGroup<3>>(
        domain: &EvaluationDomain<F>,
        chunk_size: usize,
        chunk_index: usize,
        poly: &Polynomial<F, Self>,
        rotation: Rotation,
    ) -> Vec<F> {
        domain.get_chunk_of_rotated_extended(poly, rotation, chunk_size, chunk_index)
    }
}

#[cfg(test)]
mod tests {
    use group::ff::Field;
    use pasta_curves::pallas;

    use super::{get_chunk_params, new_evaluator, Ast, BasisOps, Evaluator};
    use crate::poly::{Coeff, EvaluationDomain, ExtendedLagrangeCoeff, LagrangeCoeff};

    #[test]
    fn short_chunk_regression_test() {
        // Pick the smallest polynomial length that is guaranteed to produce a short chunk
        // on this machine.
        let k = match (1..16)
            .map(|k| (k, get_chunk_params(1 << k)))
            .find(|(k, (chunk_size, num_chunks))| (1 << k) < chunk_size * num_chunks)
            .map(|(k, _)| k)
        {
            Some(k) => k,
            None => {
                // We are on a machine with a power-of-two number of threads, and cannot
                // trigger the bug.
                eprintln!(
                    "can't find a polynomial length for short_chunk_regression_test; skipping"
                );
                return;
            }
        };
        eprintln!("Testing short-chunk regression with k = {}", k);

        fn test_case<E: Copy + Send + Sync, B: BasisOps>(
            k: u32,
            mut evaluator: Evaluator<E, pallas::Base, B>,
        ) {
            // Instantiate the evaluator with a trivial polynomial.
            let domain = EvaluationDomain::new(1, k);
            evaluator.register_poly(B::empty_poly(&domain));

            // With the bug present, these will panic.
            let _ = evaluator.evaluate(&Ast::ConstantTerm(pallas::Base::ZERO), &domain);
            let _ = evaluator.evaluate(&Ast::LinearTerm(pallas::Base::ZERO), &domain);
        }

        test_case(k, new_evaluator::<_, _, Coeff>(|| {}));
        test_case(k, new_evaluator::<_, _, LagrangeCoeff>(|| {}));
        test_case(k, new_evaluator::<_, _, ExtendedLagrangeCoeff>(|| {}));
    }
}