mirror of https://github.com/zcash/halo2.git
Merge pull request #417 from zcash/fix-assigned-usage
Expand `Assigned<F>` APIs
This commit is contained in:
commit
8abd7b74db
|
@ -122,6 +122,20 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<F: Field> AssignedCell<Assigned<F>, F> {
|
||||||
|
/// Evaluates this assigned cell's value directly, performing an unbatched inversion
|
||||||
|
/// if necessary.
|
||||||
|
///
|
||||||
|
/// If the denominator is zero, the returned cell's value is zero.
|
||||||
|
pub fn evaluate(self) -> AssignedCell<F, F> {
|
||||||
|
AssignedCell {
|
||||||
|
value: self.value.map(|v| v.evaluate()),
|
||||||
|
cell: self.cell,
|
||||||
|
_marker: Default::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<V: Clone, F: Field> AssignedCell<V, F>
|
impl<V: Clone, F: Field> AssignedCell<V, F>
|
||||||
where
|
where
|
||||||
for<'v> Assigned<F>: From<&'v V>,
|
for<'v> Assigned<F>: From<&'v V>,
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use std::ops::{Add, Mul, Neg, Sub};
|
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
|
||||||
|
|
||||||
use group::ff::Field;
|
use group::ff::Field;
|
||||||
|
|
||||||
|
@ -17,6 +17,12 @@ pub enum Assigned<F> {
|
||||||
Rational(F, F),
|
Rational(F, F),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<F: Field> From<&Assigned<F>> for Assigned<F> {
|
||||||
|
fn from(val: &Assigned<F>) -> Self {
|
||||||
|
*val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<F: Field> From<&F> for Assigned<F> {
|
impl<F: Field> From<&F> for Assigned<F> {
|
||||||
fn from(numerator: &F) -> Self {
|
fn from(numerator: &F) -> Self {
|
||||||
Assigned::Trivial(*numerator)
|
Assigned::Trivial(*numerator)
|
||||||
|
@ -35,6 +41,36 @@ impl<F: Field> From<(F, F)> for Assigned<F> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<F: Field> PartialEq for Assigned<F> {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
match (self, other) {
|
||||||
|
// At least one side is directly zero.
|
||||||
|
(Self::Zero, Self::Zero) => true,
|
||||||
|
(Self::Zero, x) | (x, Self::Zero) => x.is_zero_vartime(),
|
||||||
|
|
||||||
|
// One side is x/0 which maps to zero.
|
||||||
|
(Self::Rational(_, denominator), x) | (x, Self::Rational(_, denominator))
|
||||||
|
if denominator.is_zero_vartime() =>
|
||||||
|
{
|
||||||
|
x.is_zero_vartime()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Okay, we need to do some actual math...
|
||||||
|
(Self::Trivial(lhs), Self::Trivial(rhs)) => lhs == rhs,
|
||||||
|
(Self::Trivial(x), Self::Rational(numerator, denominator))
|
||||||
|
| (Self::Rational(numerator, denominator), Self::Trivial(x)) => {
|
||||||
|
&(*x * denominator) == numerator
|
||||||
|
}
|
||||||
|
(
|
||||||
|
Self::Rational(lhs_numerator, lhs_denominator),
|
||||||
|
Self::Rational(rhs_numerator, rhs_denominator),
|
||||||
|
) => *lhs_numerator * rhs_denominator == *lhs_denominator * rhs_numerator,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: Field> Eq for Assigned<F> {}
|
||||||
|
|
||||||
impl<F: Field> Neg for Assigned<F> {
|
impl<F: Field> Neg for Assigned<F> {
|
||||||
type Output = Assigned<F>;
|
type Output = Assigned<F>;
|
||||||
fn neg(self) -> Self::Output {
|
fn neg(self) -> Self::Output {
|
||||||
|
@ -85,6 +121,25 @@ impl<F: Field> Add<F> for Assigned<F> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<F: Field> Add<&Assigned<F>> for Assigned<F> {
|
||||||
|
type Output = Assigned<F>;
|
||||||
|
fn add(self, rhs: &Self) -> Assigned<F> {
|
||||||
|
self + *rhs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: Field> AddAssign for Assigned<F> {
|
||||||
|
fn add_assign(&mut self, rhs: Self) {
|
||||||
|
*self = *self + rhs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: Field> AddAssign<&Assigned<F>> for Assigned<F> {
|
||||||
|
fn add_assign(&mut self, rhs: &Self) {
|
||||||
|
*self = *self + rhs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<F: Field> Sub for Assigned<F> {
|
impl<F: Field> Sub for Assigned<F> {
|
||||||
type Output = Assigned<F>;
|
type Output = Assigned<F>;
|
||||||
fn sub(self, rhs: Assigned<F>) -> Assigned<F> {
|
fn sub(self, rhs: Assigned<F>) -> Assigned<F> {
|
||||||
|
@ -99,6 +154,25 @@ impl<F: Field> Sub<F> for Assigned<F> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<F: Field> Sub<&Assigned<F>> for Assigned<F> {
|
||||||
|
type Output = Assigned<F>;
|
||||||
|
fn sub(self, rhs: &Self) -> Assigned<F> {
|
||||||
|
self - *rhs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: Field> SubAssign for Assigned<F> {
|
||||||
|
fn sub_assign(&mut self, rhs: Self) {
|
||||||
|
*self = *self - rhs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: Field> SubAssign<&Assigned<F>> for Assigned<F> {
|
||||||
|
fn sub_assign(&mut self, rhs: &Self) {
|
||||||
|
*self = *self - rhs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<F: Field> Mul for Assigned<F> {
|
impl<F: Field> Mul for Assigned<F> {
|
||||||
type Output = Assigned<F>;
|
type Output = Assigned<F>;
|
||||||
fn mul(self, rhs: Assigned<F>) -> Assigned<F> {
|
fn mul(self, rhs: Assigned<F>) -> Assigned<F> {
|
||||||
|
@ -127,6 +201,25 @@ impl<F: Field> Mul<F> for Assigned<F> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<F: Field> Mul<&Assigned<F>> for Assigned<F> {
|
||||||
|
type Output = Assigned<F>;
|
||||||
|
fn mul(self, rhs: &Assigned<F>) -> Assigned<F> {
|
||||||
|
self * *rhs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: Field> MulAssign for Assigned<F> {
|
||||||
|
fn mul_assign(&mut self, rhs: Self) {
|
||||||
|
*self = *self * rhs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: Field> MulAssign<&Assigned<F>> for Assigned<F> {
|
||||||
|
fn mul_assign(&mut self, rhs: &Self) {
|
||||||
|
*self = *self * rhs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<F: Field> Assigned<F> {
|
impl<F: Field> Assigned<F> {
|
||||||
/// Returns the numerator.
|
/// Returns the numerator.
|
||||||
pub fn numerator(&self) -> F {
|
pub fn numerator(&self) -> F {
|
||||||
|
@ -146,6 +239,48 @@ impl<F: Field> Assigned<F> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns true iff this element is zero.
|
||||||
|
pub fn is_zero_vartime(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::Zero => true,
|
||||||
|
Self::Trivial(x) => x.is_zero_vartime(),
|
||||||
|
// Assigned maps x/0 -> 0.
|
||||||
|
Self::Rational(numerator, denominator) => {
|
||||||
|
numerator.is_zero_vartime() || denominator.is_zero_vartime()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Doubles this element.
|
||||||
|
#[must_use]
|
||||||
|
pub fn double(&self) -> Self {
|
||||||
|
match self {
|
||||||
|
Self::Zero => Self::Zero,
|
||||||
|
Self::Trivial(x) => Self::Trivial(x.double()),
|
||||||
|
Self::Rational(numerator, denominator) => {
|
||||||
|
Self::Rational(numerator.double(), *denominator)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Squares this element.
|
||||||
|
#[must_use]
|
||||||
|
pub fn square(&self) -> Self {
|
||||||
|
match self {
|
||||||
|
Self::Zero => Self::Zero,
|
||||||
|
Self::Trivial(x) => Self::Trivial(x.square()),
|
||||||
|
Self::Rational(numerator, denominator) => {
|
||||||
|
Self::Rational(numerator.square(), denominator.square())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Cubes this element.
|
||||||
|
#[must_use]
|
||||||
|
pub fn cube(&self) -> Self {
|
||||||
|
self.square() * self
|
||||||
|
}
|
||||||
|
|
||||||
/// Inverts this assigned value (taking the inverse of zero to be zero).
|
/// Inverts this assigned value (taking the inverse of zero to be zero).
|
||||||
pub fn invert(&self) -> Self {
|
pub fn invert(&self) -> Self {
|
||||||
match self {
|
match self {
|
||||||
|
@ -254,26 +389,108 @@ mod tests {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod proptests {
|
mod proptests {
|
||||||
use std::{
|
use std::{
|
||||||
|
cmp,
|
||||||
convert::TryFrom,
|
convert::TryFrom,
|
||||||
ops::{Add, Mul, Sub},
|
ops::{Add, Mul, Neg, Sub},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use group::ff::Field;
|
||||||
use pasta_curves::{arithmetic::FieldExt, Fp};
|
use pasta_curves::{arithmetic::FieldExt, Fp};
|
||||||
use proptest::{collection::vec, prelude::*, sample::select};
|
use proptest::{collection::vec, prelude::*, sample::select};
|
||||||
|
|
||||||
use super::Assigned;
|
use super::Assigned;
|
||||||
|
|
||||||
|
trait UnaryOperand: Neg<Output = Self> {
|
||||||
|
fn double(&self) -> Self;
|
||||||
|
fn square(&self) -> Self;
|
||||||
|
fn cube(&self) -> Self;
|
||||||
|
fn inv0(&self) -> Self;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: Field> UnaryOperand for F {
|
||||||
|
fn double(&self) -> Self {
|
||||||
|
self.double()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn square(&self) -> Self {
|
||||||
|
self.square()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cube(&self) -> Self {
|
||||||
|
self.cube()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn inv0(&self) -> Self {
|
||||||
|
self.invert().unwrap_or(F::zero())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: Field> UnaryOperand for Assigned<F> {
|
||||||
|
fn double(&self) -> Self {
|
||||||
|
self.double()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn square(&self) -> Self {
|
||||||
|
self.square()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cube(&self) -> Self {
|
||||||
|
self.cube()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn inv0(&self) -> Self {
|
||||||
|
self.invert()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
enum Operation {
|
enum UnaryOperator {
|
||||||
|
Neg,
|
||||||
|
Double,
|
||||||
|
Square,
|
||||||
|
Cube,
|
||||||
|
Inv0,
|
||||||
|
}
|
||||||
|
|
||||||
|
const UNARY_OPERATORS: &[UnaryOperator] = &[
|
||||||
|
UnaryOperator::Neg,
|
||||||
|
UnaryOperator::Double,
|
||||||
|
UnaryOperator::Square,
|
||||||
|
UnaryOperator::Cube,
|
||||||
|
UnaryOperator::Inv0,
|
||||||
|
];
|
||||||
|
|
||||||
|
impl UnaryOperator {
|
||||||
|
fn apply<F: UnaryOperand>(&self, a: F) -> F {
|
||||||
|
match self {
|
||||||
|
Self::Neg => -a,
|
||||||
|
Self::Double => a.double(),
|
||||||
|
Self::Square => a.square(),
|
||||||
|
Self::Cube => a.cube(),
|
||||||
|
Self::Inv0 => a.inv0(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
trait BinaryOperand: Sized + Add<Output = Self> + Sub<Output = Self> + Mul<Output = Self> {}
|
||||||
|
impl<F: Field> BinaryOperand for F {}
|
||||||
|
impl<F: Field> BinaryOperand for Assigned<F> {}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
enum BinaryOperator {
|
||||||
Add,
|
Add,
|
||||||
Sub,
|
Sub,
|
||||||
Mul,
|
Mul,
|
||||||
}
|
}
|
||||||
|
|
||||||
const OPERATIONS: &[Operation] = &[Operation::Add, Operation::Sub, Operation::Mul];
|
const BINARY_OPERATORS: &[BinaryOperator] = &[
|
||||||
|
BinaryOperator::Add,
|
||||||
|
BinaryOperator::Sub,
|
||||||
|
BinaryOperator::Mul,
|
||||||
|
];
|
||||||
|
|
||||||
impl Operation {
|
impl BinaryOperator {
|
||||||
fn apply<F: Add<Output = F> + Sub<Output = F> + Mul<Output = F>>(&self, a: F, b: F) -> F {
|
fn apply<F: BinaryOperand>(&self, a: F, b: F) -> F {
|
||||||
match self {
|
match self {
|
||||||
Self::Add => a + b,
|
Self::Add => a + b,
|
||||||
Self::Sub => a - b,
|
Self::Sub => a - b,
|
||||||
|
@ -282,6 +499,12 @@ mod proptests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
enum Operator {
|
||||||
|
Unary(UnaryOperator),
|
||||||
|
Binary(BinaryOperator),
|
||||||
|
}
|
||||||
|
|
||||||
prop_compose! {
|
prop_compose! {
|
||||||
/// Use narrow that can be easily reduced.
|
/// Use narrow that can be easily reduced.
|
||||||
fn arb_element()(val in any::<u64>()) -> Fp {
|
fn arb_element()(val in any::<u64>()) -> Fp {
|
||||||
|
@ -299,21 +522,44 @@ mod proptests {
|
||||||
/// Generates half of the denominators as zero to represent a deferred inversion.
|
/// Generates half of the denominators as zero to represent a deferred inversion.
|
||||||
fn arb_rational()(
|
fn arb_rational()(
|
||||||
numerator in arb_element(),
|
numerator in arb_element(),
|
||||||
denominator in prop_oneof![Just(Fp::zero()), arb_element()],
|
denominator in prop_oneof![
|
||||||
|
1 => Just(Fp::zero()),
|
||||||
|
2 => arb_element(),
|
||||||
|
],
|
||||||
) -> Assigned<Fp> {
|
) -> Assigned<Fp> {
|
||||||
Assigned::Rational(numerator, denominator)
|
Assigned::Rational(numerator, denominator)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
prop_compose! {
|
||||||
|
fn arb_operators(num_unary: usize, num_binary: usize)(
|
||||||
|
unary in vec(select(UNARY_OPERATORS), num_unary),
|
||||||
|
binary in vec(select(BINARY_OPERATORS), num_binary),
|
||||||
|
) -> Vec<Operator> {
|
||||||
|
unary.into_iter()
|
||||||
|
.map(Operator::Unary)
|
||||||
|
.chain(binary.into_iter().map(Operator::Binary))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
prop_compose! {
|
prop_compose! {
|
||||||
fn arb_testcase()(
|
fn arb_testcase()(
|
||||||
num_operations in 1usize..5,
|
num_unary in 0usize..5,
|
||||||
|
num_binary in 0usize..5,
|
||||||
)(
|
)(
|
||||||
values in vec(
|
values in vec(
|
||||||
prop_oneof![Just(Assigned::Zero), arb_trivial(), arb_rational()],
|
prop_oneof![
|
||||||
num_operations + 1),
|
1 => Just(Assigned::Zero),
|
||||||
operations in vec(select(OPERATIONS), num_operations),
|
2 => arb_trivial(),
|
||||||
) -> (Vec<Assigned<Fp>>, Vec<Operation>) {
|
2 => arb_rational(),
|
||||||
|
],
|
||||||
|
// Ensure that:
|
||||||
|
// - we have at least one value to apply unary operators to.
|
||||||
|
// - we can apply every binary operator pairwise sequentially.
|
||||||
|
cmp::max(if num_unary > 0 { 1 } else { 0 }, num_binary + 1)),
|
||||||
|
operations in arb_operators(num_unary, num_binary).prop_shuffle(),
|
||||||
|
) -> (Vec<Assigned<Fp>>, Vec<Operator>) {
|
||||||
(values, operations)
|
(values, operations)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -325,14 +571,36 @@ mod proptests {
|
||||||
let elements: Vec<_> = values.iter().cloned().map(|v| v.evaluate()).collect();
|
let elements: Vec<_> = values.iter().cloned().map(|v| v.evaluate()).collect();
|
||||||
|
|
||||||
// Apply the operations to both the deferred and evaluated values.
|
// Apply the operations to both the deferred and evaluated values.
|
||||||
let deferred_result = {
|
fn evaluate<F: UnaryOperand + BinaryOperand>(
|
||||||
let mut ops = operations.iter();
|
items: Vec<F>,
|
||||||
values.into_iter().reduce(|a, b| ops.next().unwrap().apply(a, b)).unwrap()
|
operators: &[Operator],
|
||||||
};
|
) -> F {
|
||||||
let evaluated_result = {
|
let mut ops = operators.iter();
|
||||||
let mut ops = operations.iter();
|
|
||||||
elements.into_iter().reduce(|a, b| ops.next().unwrap().apply(a, b)).unwrap()
|
// Process all binary operators. We are guaranteed to have exactly as many
|
||||||
};
|
// binary operators as we need calls to the reduction closure.
|
||||||
|
let mut res = items.into_iter().reduce(|mut a, b| loop {
|
||||||
|
match ops.next() {
|
||||||
|
Some(Operator::Unary(op)) => a = op.apply(a),
|
||||||
|
Some(Operator::Binary(op)) => break op.apply(a, b),
|
||||||
|
None => unreachable!(),
|
||||||
|
}
|
||||||
|
}).unwrap();
|
||||||
|
|
||||||
|
// Process any unary operators that weren't handled in the reduce() call
|
||||||
|
// above (either if we only had one item, or there were unary operators
|
||||||
|
// after the last binary operator). We are guaranteed to have no binary
|
||||||
|
// operators remaining at this point.
|
||||||
|
loop {
|
||||||
|
match ops.next() {
|
||||||
|
Some(Operator::Unary(op)) => res = op.apply(res),
|
||||||
|
Some(Operator::Binary(_)) => unreachable!(),
|
||||||
|
None => break res,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let deferred_result = evaluate(values, &operations);
|
||||||
|
let evaluated_result = evaluate(elements, &operations);
|
||||||
|
|
||||||
// The two should be equal, i.e. deferred inversion should commute with the
|
// The two should be equal, i.e. deferred inversion should commute with the
|
||||||
// list of operations.
|
// list of operations.
|
||||||
|
|
Loading…
Reference in New Issue