diff --git a/halo2_proofs/src/plonk/assigned.rs b/halo2_proofs/src/plonk/assigned.rs index 3c1795c1..67e7a560 100644 --- a/halo2_proofs/src/plonk/assigned.rs +++ b/halo2_proofs/src/plonk/assigned.rs @@ -41,6 +41,36 @@ impl From<(F, F)> for Assigned { } } +impl PartialEq for Assigned { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + // At least one side is directly zero. + (Self::Zero, Self::Zero) => true, + (Self::Zero, x) | (x, Self::Zero) => x.is_zero_vartime(), + + // One side is x/0 which maps to zero. + (Self::Rational(_, denominator), x) | (x, Self::Rational(_, denominator)) + if denominator.is_zero_vartime() => + { + x.is_zero_vartime() + } + + // Okay, we need to do some actual math... + (Self::Trivial(lhs), Self::Trivial(rhs)) => lhs == rhs, + (Self::Trivial(x), Self::Rational(numerator, denominator)) + | (Self::Rational(numerator, denominator), Self::Trivial(x)) => { + &(*x * denominator) == numerator + } + ( + Self::Rational(lhs_numerator, lhs_denominator), + Self::Rational(rhs_numerator, rhs_denominator), + ) => *lhs_numerator * rhs_denominator == *lhs_denominator * rhs_numerator, + } + } +} + +impl Eq for Assigned {} + impl Neg for Assigned { type Output = Assigned; fn neg(self) -> Self::Output {