halo2/halo2_proofs/src/poly/evaluator.rs

664 lines
22 KiB
Rust

use std::{
cmp, fmt,
hash::{Hash, Hasher},
marker::PhantomData,
ops::{Add, Mul, MulAssign, Neg, Sub},
sync::Arc,
};
use ff::WithSmallOrderMulGroup;
use group::ff::Field;
use super::{
Basis, Coeff, EvaluationDomain, ExtendedLagrangeCoeff, LagrangeCoeff, Polynomial, Rotation,
};
use crate::multicore;
/// Computes the `(chunk_size, num_chunks)` split used to process a polynomial of
/// `poly_len` elements in the current parallelization environment.
fn get_chunk_params(poly_len: usize) -> (usize, usize) {
    // Integer ceiling division: ceil(numerator / denominator).
    let ceil_div =
        |numerator: usize, denominator: usize| (numerator + denominator - 1) / denominator;

    // Aim for a constant multiple of the available thread count. Having more chunks
    // than threads evens out throughput when some threads are busy elsewhere, so we
    // don't stall waiting for a few stragglers to finish the final chunks.
    let target_chunks = multicore::current_num_threads() * 4;

    // Rounding up keeps the chunk size at 1 or more for any non-empty polynomial.
    let chunk_size = ceil_div(poly_len, target_chunks);

    // The rounded chunk size may cover the polynomial in fewer chunks than we aimed
    // for, so derive the real chunk count from it.
    let num_chunks = ceil_div(poly_len, chunk_size);

    (chunk_size, num_chunks)
}
/// A reference to a polynomial registered with an [`Evaluator`].
///
/// A leaf is a lightweight handle: it records which registered polynomial it refers to
/// and the rotation to query it at, rather than holding the polynomial data itself.
#[derive(Clone, Copy)]
pub(crate) struct AstLeaf<E, B: Basis> {
/// Position of the polynomial in the evaluator's list of registered polynomials.
index: usize,
/// Rotation that is applied to the polynomial when this leaf is evaluated.
rotation: Rotation,
/// Zero-sized marker tying the leaf to the evaluator context type `E` and basis `B`.
_evaluator: PhantomData<(E, B)>,
}
impl<E, B: Basis> fmt::Debug for AstLeaf<E, B> {
    /// Formats the leaf as `AstLeaf { index, rotation }`; the phantom marker is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            index, rotation, ..
        } = self;
        let mut out = f.debug_struct("AstLeaf");
        out.field("index", index);
        out.field("rotation", rotation);
        out.finish()
    }
}
impl<E, B: Basis> PartialEq for AstLeaf<E, B> {
    /// Two leaves are equal when they name the same registered polynomial at the same
    /// rotation offset.
    fn eq(&self, rhs: &Self) -> bool {
        // Rotations are compared by their raw offset, so two rotations that are
        // equivalent but have different offsets are treated as distinct.
        let same_poly = self.index == rhs.index;
        same_poly && self.rotation.0 == rhs.rotation.0
    }
}
// Marker impl: the `PartialEq` comparison for `AstLeaf` only compares an index and a
// rotation offset, so it is a total equivalence relation.
impl<E, B: Basis> Eq for AstLeaf<E, B> {}
impl<E, B: Basis> Hash for AstLeaf<E, B> {
    /// Feeds the polynomial index followed by the rotation offset into `state`, the same
    /// two values that equality compares.
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.index, state);
        Hash::hash(&self.rotation.0, state);
    }
}
impl<E, B: Basis> AstLeaf<E, B> {
    /// Returns a leaf for the same underlying polynomial, queried at `rotation`.
    ///
    /// The rotation is absolute: whatever rotation this leaf currently carries is
    /// discarded rather than composed with the new one.
    pub(crate) fn with_rotation(&self, rotation: Rotation) -> Self {
        let index = self.index;
        Self {
            index,
            rotation,
            _evaluator: PhantomData,
        }
    }
}
/// An evaluation context for polynomial operations.
///
/// This context enables us to de-duplicate queries of circuit columns (and the rotations
/// they might require), by storing a list of all the underlying polynomials involved in
/// any query (which are almost certainly column polynomials). We use the context like so:
///
/// - We register each underlying polynomial with the evaluator, which returns a reference
/// to it as a [`AstLeaf`].
/// - The references are then used to build up a [`Ast`] that represents the overall
/// operations to be applied to the polynomials.
/// - Finally, we call [`Evaluator::evaluate`] passing in the [`Ast`].
pub(crate) struct Evaluator<E, F: Field, B: Basis> {
/// Registered polynomials, in registration order; an [`AstLeaf`] refers to one of
/// them by its position in this list.
polys: Vec<Polynomial<F, B>>,
/// The context value supplied to [`new_evaluator`]. It is stored but not read by the
/// evaluator; its type `E` tags the leaves and ASTs that belong to this evaluator.
_context: E,
}
/// Creates an [`Evaluator`] with no registered polynomials.
///
/// `context` exists purely for type safety: its type becomes the evaluator's `E`
/// parameter, which is also carried by every [`AstLeaf`] the evaluator hands out and by
/// every [`Ast`] built from them. Pass the empty closure `|| {}`; each anonymous closure
/// has its own unique type, so leaves from one evaluator cannot be evaluated by another.
pub(crate) fn new_evaluator<E: Fn() + Clone, F: Field, B: Basis>(context: E) -> Evaluator<E, F, B> {
    let polys = Vec::new();
    Evaluator {
        polys,
        _context: context,
    }
}
impl<E, F: Field, B: Basis> Evaluator<E, F, B> {
    /// Adds `poly` to this evaluation context and returns a leaf handle referring to it
    /// at the current rotation.
    ///
    /// No de-duplication is attempted: registering the same polynomial twice yields two
    /// distinct leaves.
    pub(crate) fn register_poly(&mut self, poly: Polynomial<F, B>) -> AstLeaf<E, B> {
        // The handle's index is the slot the polynomial is about to occupy.
        let leaf = AstLeaf {
            index: self.polys.len(),
            rotation: Rotation::cur(),
            _evaluator: PhantomData,
        };
        self.polys.push(poly);
        leaf
    }

    /// Evaluates `ast` against the polynomials registered in this context, producing a
    /// new polynomial in basis `B`.
    ///
    /// The output is split into chunks and every chunk is computed by its own task
    /// spawned inside a `multicore::scope`.
    ///
    /// # Panics
    ///
    /// Panics if no polynomial has been registered, because the working length is taken
    /// from the first registered polynomial.
    pub(crate) fn evaluate(
        &self,
        ast: &Ast<E, F, B>,
        domain: &EvaluationDomain<F>,
    ) -> Polynomial<F, B>
    where
        E: Copy + Send + Sync,
        F: WithSmallOrderMulGroup<3>,
        B: BasisOps,
    {
        /// Everything a task needs in order to compute one chunk of the output.
        struct AstContext<'a, F: Field, B: Basis> {
            domain: &'a EvaluationDomain<F>,
            poly_len: usize,
            chunk_size: usize,
            chunk_index: usize,
            polys: &'a [Polynomial<F, B>],
        }

        /// Computes the values of `ast` restricted to the chunk described by `ctx`.
        fn recurse<E, F: WithSmallOrderMulGroup<3>, B: BasisOps>(
            ast: &Ast<E, F, B>,
            ctx: &AstContext<'_, F, B>,
        ) -> Vec<F> {
            match ast {
                Ast::Poly(leaf) => B::get_chunk_of_rotated(
                    ctx.domain,
                    ctx.chunk_size,
                    ctx.chunk_index,
                    &ctx.polys[leaf.index],
                    leaf.rotation,
                ),
                Ast::Add(a, b) => {
                    let mut sum = recurse(a, ctx);
                    let addend = recurse(b, ctx);
                    sum.iter_mut()
                        .zip(addend.iter())
                        .for_each(|(acc, term)| *acc += *term);
                    sum
                }
                Ast::Mul(AstMul(a, b)) => {
                    let mut product = recurse(a, ctx);
                    let factor = recurse(b, ctx);
                    product
                        .iter_mut()
                        .zip(factor.iter())
                        .for_each(|(acc, term)| *acc *= *term);
                    product
                }
                Ast::Scale(a, scalar) => {
                    let mut scaled = recurse(a, ctx);
                    scaled.iter_mut().for_each(|value| *value *= scalar);
                    scaled
                }
                Ast::DistributePowers(terms, base) => {
                    // Horner-style accumulation: multiply the running total by `base`
                    // and add the next term, so earlier terms pick up higher powers.
                    let mut total =
                        B::constant_term(ctx.poly_len, ctx.chunk_size, ctx.chunk_index, F::ZERO);
                    for term in terms.iter() {
                        let values = recurse(term, ctx);
                        for (slot, value) in total.iter_mut().zip(values) {
                            *slot *= base;
                            *slot += value;
                        }
                    }
                    total
                }
                Ast::LinearTerm(scalar) => B::linear_term(
                    ctx.domain,
                    ctx.poly_len,
                    ctx.chunk_size,
                    ctx.chunk_index,
                    *scalar,
                ),
                Ast::ConstantTerm(scalar) => {
                    B::constant_term(ctx.poly_len, ctx.chunk_size, ctx.chunk_index, *scalar)
                }
            }
        }

        // Every polynomial lives in the same basis, so they all share one length; take
        // it from the first registered polynomial.
        let poly_len = self.polys.first().unwrap().len();
        let (chunk_size, _) = get_chunk_params(poly_len);

        // Spawn one task per output chunk; each task evaluates the whole AST for its
        // chunk and copies the values into its slice of the output polynomial.
        let mut result = B::empty_poly(domain);
        multicore::scope(|scope| {
            for (chunk_index, out) in result.chunks_mut(chunk_size).enumerate() {
                scope.spawn(move |_| {
                    let ctx = AstContext {
                        domain,
                        poly_len,
                        chunk_size,
                        chunk_index,
                        polys: &self.polys,
                    };
                    let values = recurse(ast, &ctx);
                    out.copy_from_slice(&values);
                });
            }
        });
        result
    }
}
/// Struct representing the [`Ast::Mul`] case.
///
/// This struct exists to make the internals of this case private so that we don't
/// accidentally construct this case directly: multiplying two [`Ast`]s is only
/// implemented for the [`LagrangeCoeff`] and [`ExtendedLagrangeCoeff`] bases, and not
/// for the [`Coeff`] basis.
///
/// The two fields are the left-hand and right-hand operands of the product.
#[derive(Clone)]
pub(crate) struct AstMul<E, F: Field, B: Basis>(Arc<Ast<E, F, B>>, Arc<Ast<E, F, B>>);
impl<E, F: Field, B: Basis> fmt::Debug for AstMul<E, F, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("AstMul")
.field(&self.0)
.field(&self.1)
.finish()
}
}
/// A polynomial operation backed by an [`Evaluator`].
///
/// Sub-expressions are held behind [`Arc`]s, so cloning a node does not deep-copy the
/// tree beneath it.
#[derive(Clone)]
pub(crate) enum Ast<E, F: Field, B: Basis> {
/// A registered polynomial, queried at the rotation stored in the leaf.
Poly(AstLeaf<E, B>),
/// The element-wise sum of two sub-expressions.
Add(Arc<Ast<E, F, B>>, Arc<Ast<E, F, B>>),
/// The element-wise product of two sub-expressions; see [`AstMul`].
Mul(AstMul<E, F, B>),
/// A sub-expression with every element multiplied by a field element.
Scale(Arc<Ast<E, F, B>>, F),
/// Represents a linear combination of a vector of nodes and the powers of a
/// field element, where the nodes are ordered from highest to lowest degree
/// terms.
DistributePowers(Arc<Vec<Ast<E, F, B>>>, F),
/// The degree-1 term of a polynomial.
///
/// The field element is the coefficient of the term in the standard (coefficient)
/// basis, not a value in the evaluation basis.
LinearTerm(F),
/// The degree-0 term of a polynomial.
///
/// The field element is the same in both the standard and evaluation bases.
ConstantTerm(F),
}
impl<E, F: Field, B: Basis> Ast<E, F, B> {
    /// Builds an [`Ast::DistributePowers`] node from the terms yielded by `i` (highest
    /// degree first) and the field element `base` whose powers are distributed over them.
    pub fn distribute_powers<I: IntoIterator<Item = Self>>(i: I, base: F) -> Self {
        let terms: Vec<Self> = i.into_iter().collect();
        Ast::DistributePowers(Arc::new(terms), base)
    }
}
impl<E, F: Field, B: Basis> fmt::Debug for Ast<E, F, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Poly(leaf) => f.debug_tuple("Poly").field(leaf).finish(),
Self::Add(lhs, rhs) => f.debug_tuple("Add").field(lhs).field(rhs).finish(),
Self::Mul(x) => f.debug_tuple("Mul").field(x).finish(),
Self::Scale(base, scalar) => f.debug_tuple("Scale").field(base).field(scalar).finish(),
Self::DistributePowers(terms, base) => f
.debug_tuple("DistributePowers")
.field(terms)
.field(base)
.finish(),
Self::LinearTerm(x) => f.debug_tuple("LinearTerm").field(x).finish(),
Self::ConstantTerm(x) => f.debug_tuple("ConstantTerm").field(x).finish(),
}
}
}
impl<E, F: Field, B: Basis> From<AstLeaf<E, B>> for Ast<E, F, B> {
    /// Wraps a leaf handle as an [`Ast::Poly`] node.
    fn from(leaf: AstLeaf<E, B>) -> Self {
        Self::Poly(leaf)
    }
}
impl<E, F: Field, B: Basis> Ast<E, F, B> {
pub(crate) fn one() -> Self {
Self::ConstantTerm(F::ONE)
}
}
impl<E, F: Field, B: Basis> Neg for Ast<E, F, B> {
    type Output = Ast<E, F, B>;

    /// Negates the expression by scaling it by `-1`.
    fn neg(self) -> Self::Output {
        let minus_one = -F::ONE;
        Ast::Scale(Arc::new(self), minus_one)
    }
}
impl<E: Clone, F: Field, B: Basis> Neg for &Ast<E, F, B> {
    type Output = Ast<E, F, B>;

    /// Negates a borrowed expression by cloning it and negating the clone.
    fn neg(self) -> Self::Output {
        let owned: Ast<E, F, B> = self.clone();
        -owned
    }
}
impl<E, F: Field, B: Basis> Add for Ast<E, F, B> {
    type Output = Ast<E, F, B>;

    /// Builds an [`Ast::Add`] node with `self` on the left and `other` on the right.
    fn add(self, other: Self) -> Self::Output {
        let (lhs, rhs) = (Arc::new(self), Arc::new(other));
        Ast::Add(lhs, rhs)
    }
}
impl<'a, E: Clone, F: Field, B: Basis> Add<&'a Ast<E, F, B>> for &'a Ast<E, F, B> {
    type Output = Ast<E, F, B>;

    /// Adds two borrowed expressions by cloning both and adding the clones.
    fn add(self, other: &'a Ast<E, F, B>) -> Self::Output {
        let lhs: Ast<E, F, B> = self.clone();
        let rhs: Ast<E, F, B> = other.clone();
        lhs + rhs
    }
}
impl<E, F: Field, B: Basis> Add<AstLeaf<E, B>> for Ast<E, F, B> {
    type Output = Ast<E, F, B>;

    /// Adds a leaf to the expression, wrapping the leaf as an [`Ast::Poly`] node first.
    fn add(self, other: AstLeaf<E, B>) -> Self::Output {
        let rhs: Ast<E, F, B> = Ast::from(other);
        Ast::Add(Arc::new(self), Arc::new(rhs))
    }
}
impl<E, F: Field, B: Basis> Sub for Ast<E, F, B> {
    type Output = Ast<E, F, B>;

    /// Subtracts `other` by adding its negation.
    fn sub(self, other: Self) -> Self::Output {
        let negated = -other;
        self + negated
    }
}
impl<'a, E: Clone, F: Field, B: Basis> Sub<&'a Ast<E, F, B>> for &'a Ast<E, F, B> {
    type Output = Ast<E, F, B>;

    /// Subtracts one borrowed expression from another by adding the negation of
    /// `other`.
    fn sub(self, other: &'a Ast<E, F, B>) -> Self::Output {
        let negated: Ast<E, F, B> = -other;
        self + &negated
    }
}
impl<E, F: Field, B: Basis> Sub<AstLeaf<E, B>> for Ast<E, F, B> {
    type Output = Ast<E, F, B>;

    /// Subtracts a leaf from the expression by wrapping the leaf as a node and adding
    /// its negation.
    fn sub(self, other: AstLeaf<E, B>) -> Self::Output {
        let rhs: Ast<E, F, B> = Ast::from(other);
        self + (-rhs)
    }
}
impl<E, F: Field> Mul for Ast<E, F, LagrangeCoeff> {
    type Output = Ast<E, F, LagrangeCoeff>;

    /// Builds an [`Ast::Mul`] node with `self` on the left and `other` on the right.
    fn mul(self, other: Self) -> Self::Output {
        let factors = AstMul(Arc::new(self), Arc::new(other));
        Ast::Mul(factors)
    }
}
impl<'a, E: Clone, F: Field> Mul<&'a Ast<E, F, LagrangeCoeff>> for &'a Ast<E, F, LagrangeCoeff> {
    type Output = Ast<E, F, LagrangeCoeff>;

    /// Multiplies two borrowed expressions by cloning both and multiplying the clones.
    fn mul(self, other: &'a Ast<E, F, LagrangeCoeff>) -> Self::Output {
        let lhs: Ast<E, F, LagrangeCoeff> = self.clone();
        let rhs: Ast<E, F, LagrangeCoeff> = other.clone();
        lhs * rhs
    }
}
impl<E, F: Field> Mul<AstLeaf<E, LagrangeCoeff>> for Ast<E, F, LagrangeCoeff> {
    type Output = Ast<E, F, LagrangeCoeff>;

    /// Multiplies the expression by a leaf, wrapping the leaf as an [`Ast::Poly`] node
    /// first.
    fn mul(self, other: AstLeaf<E, LagrangeCoeff>) -> Self::Output {
        let rhs: Ast<E, F, LagrangeCoeff> = Ast::from(other);
        Ast::Mul(AstMul(Arc::new(self), Arc::new(rhs)))
    }
}
impl<E, F: Field> Mul for Ast<E, F, ExtendedLagrangeCoeff> {
    type Output = Ast<E, F, ExtendedLagrangeCoeff>;

    /// Builds an [`Ast::Mul`] node with `self` on the left and `other` on the right.
    fn mul(self, other: Self) -> Self::Output {
        let factors = AstMul(Arc::new(self), Arc::new(other));
        Ast::Mul(factors)
    }
}
impl<'a, E: Clone, F: Field> Mul<&'a Ast<E, F, ExtendedLagrangeCoeff>>
    for &'a Ast<E, F, ExtendedLagrangeCoeff>
{
    type Output = Ast<E, F, ExtendedLagrangeCoeff>;

    /// Multiplies two borrowed expressions by cloning both and multiplying the clones.
    fn mul(self, other: &'a Ast<E, F, ExtendedLagrangeCoeff>) -> Self::Output {
        let lhs: Ast<E, F, ExtendedLagrangeCoeff> = self.clone();
        let rhs: Ast<E, F, ExtendedLagrangeCoeff> = other.clone();
        lhs * rhs
    }
}
impl<E, F: Field> Mul<AstLeaf<E, ExtendedLagrangeCoeff>> for Ast<E, F, ExtendedLagrangeCoeff> {
    type Output = Ast<E, F, ExtendedLagrangeCoeff>;

    /// Multiplies the expression by a leaf, wrapping the leaf as an [`Ast::Poly`] node
    /// first.
    fn mul(self, other: AstLeaf<E, ExtendedLagrangeCoeff>) -> Self::Output {
        let rhs: Ast<E, F, ExtendedLagrangeCoeff> = Ast::from(other);
        Ast::Mul(AstMul(Arc::new(self), Arc::new(rhs)))
    }
}
impl<E, F: Field, B: Basis> Mul<F> for Ast<E, F, B> {
    type Output = Ast<E, F, B>;

    /// Builds an [`Ast::Scale`] node that scales the expression by the field element
    /// `other`.
    fn mul(self, other: F) -> Self::Output {
        let inner = Arc::new(self);
        Ast::Scale(inner, other)
    }
}
impl<E: Clone, F: Field, B: Basis> Mul<F> for &Ast<E, F, B> {
    type Output = Ast<E, F, B>;

    /// Builds an [`Ast::Scale`] node that scales a clone of the borrowed expression by
    /// the field element `other`.
    fn mul(self, other: F) -> Self::Output {
        let owned: Ast<E, F, B> = self.clone();
        Ast::Scale(Arc::new(owned), other)
    }
}
impl<E: Clone, F: Field> MulAssign for Ast<E, F, ExtendedLagrangeCoeff> {
fn mul_assign(&mut self, rhs: Self) {
*self = self.clone().mul(rhs)
}
}
/// Operations which can be performed over a given basis.
///
/// The chunk-producing methods describe a polynomial of `poly_len` elements that is
/// split into chunks of `chunk_size` elements, and return the values belonging to chunk
/// number `chunk_index`. The returned chunk is shorter than `chunk_size` when fewer
/// elements remain at the end of the polynomial.
pub(crate) trait BasisOps: Basis {
/// Creates the polynomial, in this basis, that [`Evaluator::evaluate`] writes its
/// result into.
fn empty_poly<F: WithSmallOrderMulGroup<3>>(
domain: &EvaluationDomain<F>,
) -> Polynomial<F, Self>;
/// Returns the requested chunk of the degree-0 polynomial with value `scalar`, as
/// represented in this basis.
fn constant_term<F: Field>(
poly_len: usize,
chunk_size: usize,
chunk_index: usize,
scalar: F,
) -> Vec<F>;
/// Returns the requested chunk of the degree-1 term with coefficient `scalar`, as
/// represented in this basis.
fn linear_term<F: WithSmallOrderMulGroup<3>>(
domain: &EvaluationDomain<F>,
poly_len: usize,
chunk_size: usize,
chunk_index: usize,
scalar: F,
) -> Vec<F>;
/// Returns the requested chunk of `poly` queried at `rotation`.
///
/// Not every basis supports rotation: the implementation for the standard basis
/// panics.
fn get_chunk_of_rotated<F: WithSmallOrderMulGroup<3>>(
domain: &EvaluationDomain<F>,
chunk_size: usize,
chunk_index: usize,
poly: &Polynomial<F, Self>,
rotation: Rotation,
) -> Vec<F>;
}
impl BasisOps for Coeff {
    /// Creates the output polynomial in the coefficient basis via
    /// `domain.empty_coeff()`.
    fn empty_poly<F: WithSmallOrderMulGroup<3>>(
        domain: &EvaluationDomain<F>,
    ) -> Polynomial<F, Self> {
        domain.empty_coeff()
    }

    /// Returns the requested chunk of the constant polynomial `scalar`.
    ///
    /// In the coefficient basis only the very first element (index 0 of chunk 0) is
    /// non-zero; every other element is zero.
    fn constant_term<F: Field>(
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F> {
        let mut chunk = vec![F::ZERO; cmp::min(chunk_size, poly_len - chunk_size * chunk_index)];
        if chunk_index == 0 {
            chunk[0] = scalar;
        }
        chunk
    }

    /// Returns the requested chunk of the linear term `scalar * X`.
    ///
    /// In the coefficient basis only the element at global index 1 is non-zero. Which
    /// chunk holds that element depends on the chunk size.
    fn linear_term<F: WithSmallOrderMulGroup<3>>(
        _: &EvaluationDomain<F>,
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F> {
        let mut chunk = vec![F::ZERO; cmp::min(chunk_size, poly_len - chunk_size * chunk_index)];
        // If the chunk size is 1 (e.g. if we have a small k and many threads), then the
        // linear coefficient is the only element of the second chunk, and the first
        // chunk must be left untouched: it has a single element, so writing to index 1
        // there would be out of bounds. Otherwise the chunk size is greater than one,
        // and the linear coefficient is the second element of the first chunk.
        // Note that we check against the original chunk size, not the potentially-short
        // actual size of the current chunk, because we want to know whether the size of
        // the previous chunk was 1.
        if chunk_size == 1 {
            if chunk_index == 1 {
                chunk[0] = scalar;
            }
        } else if chunk_index == 0 {
            chunk[1] = scalar;
        }
        chunk
    }

    /// Rotation is not defined for the coefficient basis.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn get_chunk_of_rotated<F: WithSmallOrderMulGroup<3>>(
        _: &EvaluationDomain<F>,
        _: usize,
        _: usize,
        _: &Polynomial<F, Self>,
        _: Rotation,
    ) -> Vec<F> {
        panic!("Can't rotate polynomials in the standard basis")
    }
}
impl BasisOps for LagrangeCoeff {
    /// Creates the output polynomial in the Lagrange basis via `domain.empty_lagrange()`.
    fn empty_poly<F: WithSmallOrderMulGroup<3>>(
        domain: &EvaluationDomain<F>,
    ) -> Polynomial<F, Self> {
        domain.empty_lagrange()
    }

    /// Returns the requested chunk of the constant polynomial `scalar`: every element
    /// of the chunk is `scalar`.
    fn constant_term<F: Field>(
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F> {
        let len = cmp::min(chunk_size, poly_len - chunk_size * chunk_index);
        vec![scalar; len]
    }

    /// Returns the requested chunk of the linear term `scalar * X`: element `i` of the
    /// polynomial is `scalar * omega^i`, where `omega` is `domain.get_omega()`.
    fn linear_term<F: WithSmallOrderMulGroup<3>>(
        domain: &EvaluationDomain<F>,
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F> {
        let omega = domain.get_omega();
        let first = chunk_size * chunk_index;
        let len = cmp::min(chunk_size, poly_len - first);

        // Start from the power of omega at the chunk's first position (times scalar),
        // then step through the chunk by repeated multiplication.
        let mut current = omega.pow_vartime([first as u64]) * scalar;
        let mut chunk = Vec::with_capacity(len);
        for _ in 0..len {
            chunk.push(current);
            current *= omega;
        }
        chunk
    }

    /// Returns the requested chunk of `poly` at `rotation`, delegating to
    /// `Polynomial::get_chunk_of_rotated`; the domain is not needed.
    fn get_chunk_of_rotated<F: WithSmallOrderMulGroup<3>>(
        _: &EvaluationDomain<F>,
        chunk_size: usize,
        chunk_index: usize,
        poly: &Polynomial<F, Self>,
        rotation: Rotation,
    ) -> Vec<F> {
        poly.get_chunk_of_rotated(rotation, chunk_size, chunk_index)
    }
}
impl BasisOps for ExtendedLagrangeCoeff {
    /// Creates the output polynomial in the extended Lagrange basis via
    /// `domain.empty_extended()`.
    fn empty_poly<F: WithSmallOrderMulGroup<3>>(
        domain: &EvaluationDomain<F>,
    ) -> Polynomial<F, Self> {
        domain.empty_extended()
    }

    /// Returns the requested chunk of the constant polynomial `scalar`: every element
    /// of the chunk is `scalar`.
    fn constant_term<F: Field>(
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F> {
        let len = cmp::min(chunk_size, poly_len - chunk_size * chunk_index);
        vec![scalar; len]
    }

    /// Returns the requested chunk of the linear term `scalar * X`: element `i` of the
    /// polynomial is `omega^i * F::ZETA * scalar`, where `omega` is
    /// `domain.get_extended_omega()`.
    fn linear_term<F: WithSmallOrderMulGroup<3>>(
        domain: &EvaluationDomain<F>,
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F> {
        let omega = domain.get_extended_omega();
        let first = chunk_size * chunk_index;
        let len = cmp::min(chunk_size, poly_len - first);

        // Start from the power of the extended omega at the chunk's first position
        // (times zeta and scalar), then step through the chunk by repeated
        // multiplication.
        let mut current = omega.pow_vartime([first as u64]) * F::ZETA * scalar;
        let mut chunk = Vec::with_capacity(len);
        for _ in 0..len {
            chunk.push(current);
            current *= omega;
        }
        chunk
    }

    /// Returns the requested chunk of `poly` at `rotation`, delegating to
    /// `EvaluationDomain::get_chunk_of_rotated_extended`.
    fn get_chunk_of_rotated<F: WithSmallOrderMulGroup<3>>(
        domain: &EvaluationDomain<F>,
        chunk_size: usize,
        chunk_index: usize,
        poly: &Polynomial<F, Self>,
        rotation: Rotation,
    ) -> Vec<F> {
        domain.get_chunk_of_rotated_extended(poly, rotation, chunk_size, chunk_index)
    }
}
#[cfg(test)]
mod tests {
    use group::ff::Field;
    use pasta_curves::pallas;

    use super::{get_chunk_params, new_evaluator, Ast, BasisOps, Evaluator};
    use crate::poly::{Coeff, EvaluationDomain, ExtendedLagrangeCoeff, LagrangeCoeff};

    /// Regression test: evaluating constant and linear terms must not panic when the
    /// final chunk of the polynomial is shorter than the chunk size.
    #[test]
    fn short_chunk_regression_test() {
        // Search for the smallest `k` whose polynomial length 2^k is not an exact
        // multiple of the chunk size on this machine, i.e. the last chunk is short.
        let short_chunk_k = (1..16).find(|&k| {
            let (chunk_size, num_chunks) = get_chunk_params(1 << k);
            (1 << k) < chunk_size * num_chunks
        });

        let k = match short_chunk_k {
            Some(k) => k,
            None => {
                // We are on a machine with a power-of-two number of threads, and cannot
                // trigger the bug.
                eprintln!(
                    "can't find a polynomial length for short_chunk_regression_test; skipping"
                );
                return;
            }
        };
        eprintln!("Testing short-chunk regression with k = {}", k);

        /// Runs the regression scenario for one basis.
        fn test_case<E: Copy + Send + Sync, B: BasisOps>(
            k: u32,
            mut evaluator: Evaluator<E, pallas::Base, B>,
        ) {
            // The evaluator needs at least one registered polynomial; a trivial one
            // is enough.
            let domain = EvaluationDomain::new(1, k);
            let trivial_poly = B::empty_poly(&domain);
            evaluator.register_poly(trivial_poly);

            // With the bug present, these will panic.
            let constant_ast = Ast::ConstantTerm(pallas::Base::ZERO);
            let _ = evaluator.evaluate(&constant_ast, &domain);
            let linear_ast = Ast::LinearTerm(pallas::Base::ZERO);
            let _ = evaluator.evaluate(&linear_ast, &domain);
        }

        test_case(k, new_evaluator::<_, _, Coeff>(|| {}));
        test_case(k, new_evaluator::<_, _, LagrangeCoeff>(|| {}));
        test_case(k, new_evaluator::<_, _, ExtendedLagrangeCoeff>(|| {}));
    }
}