diff --git a/halo2_proofs/src/plonk/assigned.rs b/halo2_proofs/src/plonk/assigned.rs index 637d2a58..fcf6b2a3 100644 --- a/halo2_proofs/src/plonk/assigned.rs +++ b/halo2_proofs/src/plonk/assigned.rs @@ -389,26 +389,108 @@ mod tests { #[cfg(test)] mod proptests { use std::{ + cmp, convert::TryFrom, - ops::{Add, Mul, Sub}, + ops::{Add, Mul, Neg, Sub}, }; + use group::ff::Field; use pasta_curves::{arithmetic::FieldExt, Fp}; use proptest::{collection::vec, prelude::*, sample::select}; use super::Assigned; + trait UnaryOperand: Neg { + fn double(&self) -> Self; + fn square(&self) -> Self; + fn cube(&self) -> Self; + fn inv0(&self) -> Self; + } + + impl UnaryOperand for F { + fn double(&self) -> Self { + self.double() + } + + fn square(&self) -> Self { + self.square() + } + + fn cube(&self) -> Self { + self.cube() + } + + fn inv0(&self) -> Self { + self.invert().unwrap_or(F::zero()) + } + } + + impl UnaryOperand for Assigned { + fn double(&self) -> Self { + self.double() + } + + fn square(&self) -> Self { + self.square() + } + + fn cube(&self) -> Self { + self.cube() + } + + fn inv0(&self) -> Self { + self.invert() + } + } + #[derive(Clone, Debug)] - enum Operation { + enum UnaryOperator { + Neg, + Double, + Square, + Cube, + Inv0, + } + + const UNARY_OPERATORS: &[UnaryOperator] = &[ + UnaryOperator::Neg, + UnaryOperator::Double, + UnaryOperator::Square, + UnaryOperator::Cube, + UnaryOperator::Inv0, + ]; + + impl UnaryOperator { + fn apply(&self, a: F) -> F { + match self { + Self::Neg => -a, + Self::Double => a.double(), + Self::Square => a.square(), + Self::Cube => a.cube(), + Self::Inv0 => a.inv0(), + } + } + } + + trait BinaryOperand: Sized + Add + Sub + Mul {} + impl BinaryOperand for F {} + impl BinaryOperand for Assigned {} + + #[derive(Clone, Debug)] + enum BinaryOperator { Add, Sub, Mul, } - const OPERATIONS: &[Operation] = &[Operation::Add, Operation::Sub, Operation::Mul]; + const BINARY_OPERATORS: &[BinaryOperator] = &[ + BinaryOperator::Add, + BinaryOperator::Sub, + BinaryOperator::Mul, + ]; - impl Operation { - fn apply + Sub + Mul>(&self, a: F, b: F) -> F { + impl BinaryOperator { + fn apply(&self, a: F, b: F) -> F { match self { Self::Add => a + b, Self::Sub => a - b, @@ -417,6 +499,12 @@ mod proptests { } } + #[derive(Clone, Debug)] + enum Operator { + Unary(UnaryOperator), + Binary(BinaryOperator), + } + prop_compose! { /// Use narrow that can be easily reduced. fn arb_element()(val in any::()) -> Fp { @@ -440,15 +528,31 @@ mod proptests { } } + prop_compose! { + fn arb_operators(num_unary: usize, num_binary: usize)( + unary in vec(select(UNARY_OPERATORS), num_unary), + binary in vec(select(BINARY_OPERATORS), num_binary), + ) -> Vec { + unary.into_iter() + .map(Operator::Unary) + .chain(binary.into_iter().map(Operator::Binary)) + .collect() + } + } + prop_compose! { fn arb_testcase()( - num_operations in 1usize..5, + num_unary in 0usize..5, + num_binary in 0usize..5, )( values in vec( prop_oneof![Just(Assigned::Zero), arb_trivial(), arb_rational()], - num_operations + 1), - operations in vec(select(OPERATIONS), num_operations), - ) -> (Vec>, Vec) { + // Ensure that: + // - we have at least one value to apply unary operators to. + // - we can apply every binary operator pairwise sequentially. + cmp::max(if num_unary > 0 { 1 } else { 0 }, num_binary + 1)), + operations in arb_operators(num_unary, num_binary).prop_shuffle(), + ) -> (Vec>, Vec) { (values, operations) } } @@ -460,14 +564,36 @@ mod proptests { let elements: Vec<_> = values.iter().cloned().map(|v| v.evaluate()).collect(); // Apply the operations to both the deferred and evaluated values. - let deferred_result = { - let mut ops = operations.iter(); - values.into_iter().reduce(|a, b| ops.next().unwrap().apply(a, b)).unwrap() - }; - let evaluated_result = { - let mut ops = operations.iter(); - elements.into_iter().reduce(|a, b| ops.next().unwrap().apply(a, b)).unwrap() - }; + fn evaluate( + items: Vec, + operators: &[Operator], + ) -> F { + let mut ops = operators.iter(); + + // Process all binary operators. We are guaranteed to have exactly as many + // binary operators as we need calls to the reduction closure. + let mut res = items.into_iter().reduce(|mut a, b| loop { + match ops.next() { + Some(Operator::Unary(op)) => a = op.apply(a), + Some(Operator::Binary(op)) => break op.apply(a, b), + None => unreachable!(), + } + }).unwrap(); + + // Process any unary operators that weren't handled in the reduce() call + // above (either if we only had one item, or there were unary operators + // after the last binary operator). We are guaranteed to have no binary + // operators remaining at this point. + loop { + match ops.next() { + Some(Operator::Unary(op)) => res = op.apply(res), + Some(Operator::Binary(_)) => unreachable!(), + None => break res, + } + } + } + let deferred_result = evaluate(values, &operations); + let evaluated_result = evaluate(elements, &operations); // The two should be equal, i.e. deferred inversion should commute with the // list of operations.