mirror of https://github.com/zcash/halo2.git
Merge pull request #78 from Brechtpd/eval-mem
Reduce memory use in h evaluation
This commit is contained in:
commit
ad425ed3d1
|
@ -46,6 +46,22 @@ pub enum ValueSource {
|
|||
Advice(usize, usize),
|
||||
/// This is an instance (external) column
|
||||
Instance(usize, usize),
|
||||
/// beta
|
||||
Beta(),
|
||||
/// gamma
|
||||
Gamma(),
|
||||
/// theta
|
||||
Theta(),
|
||||
/// y
|
||||
Y(),
|
||||
/// Previous value
|
||||
PreviousValue(),
|
||||
}
|
||||
|
||||
impl Default for ValueSource {
|
||||
fn default() -> Self {
|
||||
ValueSource::Constant(0)
|
||||
}
|
||||
}
|
||||
|
||||
impl ValueSource {
|
||||
|
@ -58,6 +74,11 @@ impl ValueSource {
|
|||
fixed_values: &[Polynomial<F, B>],
|
||||
advice_values: &[Polynomial<F, B>],
|
||||
instance_values: &[Polynomial<F, B>],
|
||||
beta: &F,
|
||||
gamma: &F,
|
||||
theta: &F,
|
||||
y: &F,
|
||||
previous_value: &F,
|
||||
) -> F {
|
||||
match self {
|
||||
ValueSource::Constant(idx) => constants[*idx],
|
||||
|
@ -71,6 +92,11 @@ impl ValueSource {
|
|||
ValueSource::Instance(column_index, rotation) => {
|
||||
instance_values[*column_index][rotations[*rotation]]
|
||||
}
|
||||
ValueSource::Beta() => *beta,
|
||||
ValueSource::Gamma() => *gamma,
|
||||
ValueSource::Theta() => *theta,
|
||||
ValueSource::Y() => *y,
|
||||
ValueSource::PreviousValue() => *previous_value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -84,14 +110,14 @@ pub enum Calculation {
|
|||
Sub(ValueSource, ValueSource),
|
||||
/// This is a product
|
||||
Mul(ValueSource, ValueSource),
|
||||
/// This is a square
|
||||
Square(ValueSource),
|
||||
/// This is a double
|
||||
Double(ValueSource),
|
||||
/// This is a negation
|
||||
Negate(ValueSource),
|
||||
/// This is `(a + beta) * b`
|
||||
LcBeta(ValueSource, ValueSource),
|
||||
/// This is `a * theta + b`
|
||||
LcTheta(ValueSource, ValueSource),
|
||||
/// This is `a + gamma`
|
||||
AddGamma(ValueSource),
|
||||
/// This is Horner's rule: `val = a; val = val * c + b[]`
|
||||
Horner(ValueSource, Vec<ValueSource>, ValueSource),
|
||||
/// This is a simple assignment
|
||||
Store(ValueSource),
|
||||
}
|
||||
|
@ -109,146 +135,73 @@ impl Calculation {
|
|||
beta: &F,
|
||||
gamma: &F,
|
||||
theta: &F,
|
||||
y: &F,
|
||||
previous_value: &F,
|
||||
) -> F {
|
||||
let get_value = |value: &ValueSource| {
|
||||
value.get(
|
||||
rotations,
|
||||
constants,
|
||||
intermediates,
|
||||
fixed_values,
|
||||
advice_values,
|
||||
instance_values,
|
||||
beta,
|
||||
gamma,
|
||||
theta,
|
||||
y,
|
||||
previous_value,
|
||||
)
|
||||
};
|
||||
match self {
|
||||
Calculation::Add(a, b) => {
|
||||
let a = a.get(
|
||||
rotations,
|
||||
constants,
|
||||
intermediates,
|
||||
fixed_values,
|
||||
advice_values,
|
||||
instance_values,
|
||||
);
|
||||
let b = b.get(
|
||||
rotations,
|
||||
constants,
|
||||
intermediates,
|
||||
fixed_values,
|
||||
advice_values,
|
||||
instance_values,
|
||||
);
|
||||
a + b
|
||||
Calculation::Add(a, b) => get_value(a) + get_value(b),
|
||||
Calculation::Sub(a, b) => get_value(a) - get_value(b),
|
||||
Calculation::Mul(a, b) => get_value(a) * get_value(b),
|
||||
Calculation::Square(v) => get_value(v).square(),
|
||||
Calculation::Double(v) => get_value(v).double(),
|
||||
Calculation::Negate(v) => -get_value(v),
|
||||
Calculation::Horner(start_value, parts, factor) => {
|
||||
let factor = get_value(factor);
|
||||
let mut value = get_value(start_value);
|
||||
for part in parts.iter() {
|
||||
value = value * factor + get_value(part);
|
||||
}
|
||||
value
|
||||
}
|
||||
Calculation::Sub(a, b) => {
|
||||
let a = a.get(
|
||||
rotations,
|
||||
constants,
|
||||
intermediates,
|
||||
fixed_values,
|
||||
advice_values,
|
||||
instance_values,
|
||||
);
|
||||
let b = b.get(
|
||||
rotations,
|
||||
constants,
|
||||
intermediates,
|
||||
fixed_values,
|
||||
advice_values,
|
||||
instance_values,
|
||||
);
|
||||
a - b
|
||||
}
|
||||
Calculation::Mul(a, b) => {
|
||||
let a = a.get(
|
||||
rotations,
|
||||
constants,
|
||||
intermediates,
|
||||
fixed_values,
|
||||
advice_values,
|
||||
instance_values,
|
||||
);
|
||||
let b = b.get(
|
||||
rotations,
|
||||
constants,
|
||||
intermediates,
|
||||
fixed_values,
|
||||
advice_values,
|
||||
instance_values,
|
||||
);
|
||||
a * b
|
||||
}
|
||||
Calculation::Negate(v) => -v.get(
|
||||
rotations,
|
||||
constants,
|
||||
intermediates,
|
||||
fixed_values,
|
||||
advice_values,
|
||||
instance_values,
|
||||
),
|
||||
Calculation::LcBeta(a, b) => {
|
||||
let a = a.get(
|
||||
rotations,
|
||||
constants,
|
||||
intermediates,
|
||||
fixed_values,
|
||||
advice_values,
|
||||
instance_values,
|
||||
);
|
||||
let b = b.get(
|
||||
rotations,
|
||||
constants,
|
||||
intermediates,
|
||||
fixed_values,
|
||||
advice_values,
|
||||
instance_values,
|
||||
);
|
||||
(a + beta) * b
|
||||
}
|
||||
Calculation::LcTheta(a, b) => {
|
||||
let a = a.get(
|
||||
rotations,
|
||||
constants,
|
||||
intermediates,
|
||||
fixed_values,
|
||||
advice_values,
|
||||
instance_values,
|
||||
);
|
||||
let b = b.get(
|
||||
rotations,
|
||||
constants,
|
||||
intermediates,
|
||||
fixed_values,
|
||||
advice_values,
|
||||
instance_values,
|
||||
);
|
||||
a * theta + b
|
||||
}
|
||||
Calculation::AddGamma(v) => {
|
||||
v.get(
|
||||
rotations,
|
||||
constants,
|
||||
intermediates,
|
||||
fixed_values,
|
||||
advice_values,
|
||||
instance_values,
|
||||
) + gamma
|
||||
}
|
||||
Calculation::Store(v) => v.get(
|
||||
rotations,
|
||||
constants,
|
||||
intermediates,
|
||||
fixed_values,
|
||||
advice_values,
|
||||
instance_values,
|
||||
),
|
||||
Calculation::Store(v) => get_value(v),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// EvaluationData
|
||||
/// Evaluator
|
||||
#[derive(Default, Debug)]
|
||||
pub struct Evaluator<C: CurveAffine> {
|
||||
/// Custom gates evalution
|
||||
pub custom_gates: GraphEvaluator<C>,
|
||||
/// Lookups evalution
|
||||
pub lookups: Vec<GraphEvaluator<C>>,
|
||||
}
|
||||
|
||||
/// GraphEvaluator
|
||||
#[derive(Debug)]
|
||||
pub struct GraphEvaluator<C: CurveAffine> {
|
||||
/// Constants
|
||||
pub constants: Vec<C::ScalarExt>,
|
||||
/// Rotations
|
||||
pub rotations: Vec<i32>,
|
||||
/// Calculations
|
||||
pub calculations: Vec<CalculationInfo>,
|
||||
/// Value parts
|
||||
pub value_parts: Vec<ValueSource>,
|
||||
/// Lookup results
|
||||
pub lookup_results: Vec<Calculation>,
|
||||
/// Number of intermediates
|
||||
pub num_intermediates: usize,
|
||||
}
|
||||
|
||||
/// EvaluationData
|
||||
#[derive(Default, Debug)]
|
||||
pub struct EvaluationData<C: CurveAffine> {
|
||||
/// Intermediates
|
||||
pub intermediates: Vec<C::ScalarExt>,
|
||||
/// Rotations
|
||||
pub rotations: Vec<usize>,
|
||||
}
|
||||
|
||||
/// CaluclationInfo
|
||||
|
@ -256,205 +209,67 @@ pub struct Evaluator<C: CurveAffine> {
|
|||
pub struct CalculationInfo {
|
||||
/// Calculation
|
||||
pub calculation: Calculation,
|
||||
/// How many times this calculation is used
|
||||
pub counter: usize,
|
||||
/// Target
|
||||
pub target: usize,
|
||||
}
|
||||
|
||||
impl<C: CurveAffine> Evaluator<C> {
|
||||
/// Creates a new evaluation structure
|
||||
pub fn new(cs: &ConstraintSystem<C::ScalarExt>) -> Self {
|
||||
let mut ev = Evaluator::default();
|
||||
ev.add_constant(&C::ScalarExt::zero());
|
||||
ev.add_constant(&C::ScalarExt::one());
|
||||
|
||||
// Custom gates
|
||||
let mut parts = Vec::new();
|
||||
for gate in cs.gates.iter() {
|
||||
for poly in gate.polynomials().iter() {
|
||||
let vs = ev.add_expression(poly);
|
||||
ev.value_parts.push(vs);
|
||||
}
|
||||
parts.extend(
|
||||
gate.polynomials()
|
||||
.iter()
|
||||
.map(|poly| ev.custom_gates.add_expression(poly)),
|
||||
);
|
||||
}
|
||||
ev.custom_gates.add_calculation(Calculation::Horner(
|
||||
ValueSource::PreviousValue(),
|
||||
parts,
|
||||
ValueSource::Y(),
|
||||
));
|
||||
|
||||
// Lookups
|
||||
for lookup in cs.lookups.iter() {
|
||||
let evaluate_lc = |ev: &mut Evaluator<_>, expressions: &Vec<Expression<_>>| {
|
||||
let mut graph = GraphEvaluator::default();
|
||||
|
||||
let mut evaluate_lc = |expressions: &Vec<Expression<_>>| {
|
||||
let parts = expressions
|
||||
.iter()
|
||||
.map(|expr| ev.add_expression(expr))
|
||||
.collect::<Vec<_>>();
|
||||
let mut lc = parts[0];
|
||||
for part in parts.iter().skip(1) {
|
||||
lc = ev.add_calculation(Calculation::LcTheta(lc, *part));
|
||||
}
|
||||
lc
|
||||
.map(|expr| graph.add_expression(expr))
|
||||
.collect();
|
||||
graph.add_calculation(Calculation::Horner(
|
||||
ValueSource::Constant(0),
|
||||
parts,
|
||||
ValueSource::Theta(),
|
||||
))
|
||||
};
|
||||
|
||||
// Input coset
|
||||
let compressed_input_coset = evaluate_lc(&mut ev, &lookup.input_expressions);
|
||||
let compressed_input_coset = evaluate_lc(&lookup.input_expressions);
|
||||
// table coset
|
||||
let compressed_table_coset = evaluate_lc(&mut ev, &lookup.table_expressions);
|
||||
let compressed_table_coset = evaluate_lc(&lookup.table_expressions);
|
||||
// z(\omega X) (a'(X) + \beta) (s'(X) + \gamma)
|
||||
let right_gamma = ev.add_calculation(Calculation::AddGamma(compressed_table_coset));
|
||||
ev.lookup_results
|
||||
.push(Calculation::LcBeta(compressed_input_coset, right_gamma));
|
||||
let right_gamma = graph.add_calculation(Calculation::Add(
|
||||
compressed_table_coset,
|
||||
ValueSource::Gamma(),
|
||||
));
|
||||
let lc = graph.add_calculation(Calculation::Add(
|
||||
compressed_input_coset,
|
||||
ValueSource::Beta(),
|
||||
));
|
||||
graph.add_calculation(Calculation::Mul(lc, right_gamma));
|
||||
|
||||
ev.lookups.push(graph);
|
||||
}
|
||||
|
||||
ev
|
||||
}
|
||||
|
||||
/// Adds a rotation
|
||||
fn add_rotation(&mut self, rotation: &Rotation) -> usize {
|
||||
let position = self.rotations.iter().position(|&c| c == rotation.0);
|
||||
match position {
|
||||
Some(pos) => pos,
|
||||
None => {
|
||||
self.rotations.push(rotation.0);
|
||||
self.rotations.len() - 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a constant
|
||||
fn add_constant(&mut self, constant: &C::ScalarExt) -> ValueSource {
|
||||
let position = self.constants.iter().position(|&c| c == *constant);
|
||||
ValueSource::Constant(match position {
|
||||
Some(pos) => pos,
|
||||
None => {
|
||||
self.constants.push(*constant);
|
||||
self.constants.len() - 1
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Adds a calculation.
|
||||
/// Currently does the simplest thing possible: just stores the
|
||||
/// resulting value so the result can be reused when that calculation
|
||||
/// is done multiple times.
|
||||
fn add_calculation(&mut self, calculation: Calculation) -> ValueSource {
|
||||
let position = self
|
||||
.calculations
|
||||
.iter()
|
||||
.position(|c| c.calculation == calculation);
|
||||
match position {
|
||||
Some(pos) => {
|
||||
self.calculations[pos].counter += 1;
|
||||
ValueSource::Intermediate(pos)
|
||||
}
|
||||
None => {
|
||||
self.calculations.push(CalculationInfo {
|
||||
counter: 1,
|
||||
calculation,
|
||||
});
|
||||
ValueSource::Intermediate(self.calculations.len() - 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates an optimized evaluation for the expression
|
||||
fn add_expression(&mut self, expr: &Expression<C::ScalarExt>) -> ValueSource {
|
||||
match expr {
|
||||
Expression::Constant(scalar) => self.add_constant(scalar),
|
||||
Expression::Selector(_selector) => unreachable!(),
|
||||
Expression::Fixed {
|
||||
query_index: _,
|
||||
column_index,
|
||||
rotation,
|
||||
} => {
|
||||
let rot_idx = self.add_rotation(rotation);
|
||||
self.add_calculation(Calculation::Store(ValueSource::Fixed(
|
||||
*column_index,
|
||||
rot_idx,
|
||||
)))
|
||||
}
|
||||
Expression::Advice {
|
||||
query_index: _,
|
||||
column_index,
|
||||
rotation,
|
||||
} => {
|
||||
let rot_idx = self.add_rotation(rotation);
|
||||
self.add_calculation(Calculation::Store(ValueSource::Advice(
|
||||
*column_index,
|
||||
rot_idx,
|
||||
)))
|
||||
}
|
||||
Expression::Instance {
|
||||
query_index: _,
|
||||
column_index,
|
||||
rotation,
|
||||
} => {
|
||||
let rot_idx = self.add_rotation(rotation);
|
||||
self.add_calculation(Calculation::Store(ValueSource::Instance(
|
||||
*column_index,
|
||||
rot_idx,
|
||||
)))
|
||||
}
|
||||
Expression::Negated(a) => match **a {
|
||||
Expression::Constant(scalar) => self.add_constant(&-scalar),
|
||||
_ => {
|
||||
let result_a = self.add_expression(a);
|
||||
match result_a {
|
||||
ValueSource::Constant(0) => result_a,
|
||||
_ => self.add_calculation(Calculation::Negate(result_a)),
|
||||
}
|
||||
}
|
||||
},
|
||||
Expression::Sum(a, b) => {
|
||||
// Undo subtraction stored as a + (-b) in expressions
|
||||
match &**b {
|
||||
Expression::Negated(b_int) => {
|
||||
let result_a = self.add_expression(a);
|
||||
let result_b = self.add_expression(b_int);
|
||||
if result_a == ValueSource::Constant(0) {
|
||||
self.add_calculation(Calculation::Negate(result_b))
|
||||
} else if result_b == ValueSource::Constant(0) {
|
||||
result_a
|
||||
} else {
|
||||
self.add_calculation(Calculation::Sub(result_a, result_b))
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
let result_a = self.add_expression(a);
|
||||
let result_b = self.add_expression(b);
|
||||
if result_a == ValueSource::Constant(0) {
|
||||
result_b
|
||||
} else if result_b == ValueSource::Constant(0) {
|
||||
result_a
|
||||
} else if result_a <= result_b {
|
||||
self.add_calculation(Calculation::Add(result_a, result_b))
|
||||
} else {
|
||||
self.add_calculation(Calculation::Add(result_b, result_a))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Expression::Product(a, b) => {
|
||||
let result_a = self.add_expression(a);
|
||||
let result_b = self.add_expression(b);
|
||||
if result_a == ValueSource::Constant(0) || result_b == ValueSource::Constant(0) {
|
||||
ValueSource::Constant(0)
|
||||
} else if result_a == ValueSource::Constant(1) {
|
||||
result_b
|
||||
} else if result_b == ValueSource::Constant(1) {
|
||||
result_a
|
||||
} else if result_a <= result_b {
|
||||
self.add_calculation(Calculation::Mul(result_a, result_b))
|
||||
} else {
|
||||
self.add_calculation(Calculation::Mul(result_b, result_a))
|
||||
}
|
||||
}
|
||||
Expression::Scaled(a, f) => {
|
||||
if *f == C::ScalarExt::zero() {
|
||||
ValueSource::Constant(0)
|
||||
} else if *f == C::ScalarExt::one() {
|
||||
self.add_expression(a)
|
||||
} else {
|
||||
let cst = self.add_constant(f);
|
||||
let result_a = self.add_expression(a);
|
||||
self.add_calculation(Calculation::Mul(result_a, cst))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate h poly
|
||||
pub(in crate::plonk) fn evaluate_h(
|
||||
&self,
|
||||
|
@ -473,7 +288,6 @@ impl<C: CurveAffine> Evaluator<C> {
|
|||
let rot_scale = 1 << (domain.extended_k() - domain.k());
|
||||
let fixed = &pk.fixed_cosets[..];
|
||||
let extended_omega = domain.get_extended_omega();
|
||||
let num_lookups = pk.vk.cs.lookups.len();
|
||||
let isize = size as i32;
|
||||
let one = C::ScalarExt::one();
|
||||
let l0 = &pk.l0;
|
||||
|
@ -482,76 +296,38 @@ impl<C: CurveAffine> Evaluator<C> {
|
|||
let p = &pk.vk.cs.permutation;
|
||||
|
||||
let mut values = domain.empty_extended();
|
||||
let mut lookup_values = vec![C::Scalar::zero(); size * num_lookups];
|
||||
|
||||
// Core expression evaluations
|
||||
let num_threads = multicore::current_num_threads();
|
||||
let mut table_values_box = ThreadBox::wrap(&mut lookup_values);
|
||||
for (((advice, instance), lookups), permutation) in advice
|
||||
.iter()
|
||||
.zip(instance.iter())
|
||||
.zip(lookups.iter())
|
||||
.zip(permutations.iter())
|
||||
{
|
||||
// Custom gates
|
||||
multicore::scope(|scope| {
|
||||
let chunk_size = (size + num_threads - 1) / num_threads;
|
||||
for (thread_idx, values) in values.chunks_mut(chunk_size).enumerate() {
|
||||
let start = thread_idx * chunk_size;
|
||||
scope.spawn(move |_| {
|
||||
let table_values = table_values_box.unwrap();
|
||||
let mut rotations = vec![0usize; self.rotations.len()];
|
||||
let mut intermediates: Vec<C::ScalarExt> =
|
||||
vec![C::ScalarExt::zero(); self.calculations.len()];
|
||||
let mut eval_data = self.custom_gates.instance();
|
||||
for (i, value) in values.iter_mut().enumerate() {
|
||||
let idx = start + i;
|
||||
|
||||
// All rotation index values
|
||||
for (rot_idx, rot) in self.rotations.iter().enumerate() {
|
||||
rotations[rot_idx] = get_rotation_idx(idx, *rot, rot_scale, isize);
|
||||
}
|
||||
|
||||
// All calculations, with cached intermediate results
|
||||
for (i_idx, calc) in self.calculations.iter().enumerate() {
|
||||
intermediates[i_idx] = calc.calculation.evaluate(
|
||||
&rotations,
|
||||
&self.constants,
|
||||
&intermediates,
|
||||
fixed,
|
||||
advice,
|
||||
instance,
|
||||
&beta,
|
||||
&gamma,
|
||||
&theta,
|
||||
);
|
||||
}
|
||||
|
||||
// Accumulate value parts
|
||||
for value_part in self.value_parts.iter() {
|
||||
*value = *value * y
|
||||
+ value_part.get(
|
||||
&rotations,
|
||||
&self.constants,
|
||||
&intermediates,
|
||||
fixed,
|
||||
advice,
|
||||
instance,
|
||||
);
|
||||
}
|
||||
|
||||
// Values required for the lookups
|
||||
for (t, table_result) in self.lookup_results.iter().enumerate() {
|
||||
table_values[t * size + idx] = table_result.evaluate(
|
||||
&rotations,
|
||||
&self.constants,
|
||||
&intermediates,
|
||||
fixed,
|
||||
advice,
|
||||
instance,
|
||||
&beta,
|
||||
&gamma,
|
||||
&theta,
|
||||
);
|
||||
}
|
||||
*value = self.custom_gates.evaluate(
|
||||
&mut eval_data,
|
||||
fixed,
|
||||
advice,
|
||||
instance,
|
||||
&beta,
|
||||
&gamma,
|
||||
&theta,
|
||||
&y,
|
||||
value,
|
||||
idx,
|
||||
rot_scale,
|
||||
isize,
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -655,11 +431,27 @@ impl<C: CurveAffine> Evaluator<C> {
|
|||
.coeff_to_extended(lookup.permuted_table_poly.clone());
|
||||
|
||||
// Lookup constraints
|
||||
let table = &lookup_values[n * size..(n + 1) * size];
|
||||
parallelize(&mut values, |values, start| {
|
||||
let lookup_evaluator = &self.lookups[n];
|
||||
let mut eval_data = lookup_evaluator.instance();
|
||||
for (i, value) in values.iter_mut().enumerate() {
|
||||
let idx = start + i;
|
||||
|
||||
let table_value = lookup_evaluator.evaluate(
|
||||
&mut eval_data,
|
||||
fixed,
|
||||
advice,
|
||||
instance,
|
||||
&beta,
|
||||
&gamma,
|
||||
&theta,
|
||||
&y,
|
||||
&C::ScalarExt::zero(),
|
||||
idx,
|
||||
rot_scale,
|
||||
isize,
|
||||
);
|
||||
|
||||
let r_next = get_rotation_idx(idx, 1, rot_scale, isize);
|
||||
let r_prev = get_rotation_idx(idx, -1, rot_scale, isize);
|
||||
|
||||
|
@ -679,7 +471,7 @@ impl<C: CurveAffine> Evaluator<C> {
|
|||
+ ((product_coset[r_next]
|
||||
* (permuted_input_coset[idx] + beta)
|
||||
* (permuted_table_coset[idx] + gamma)
|
||||
- product_coset[idx] * table[idx])
|
||||
- product_coset[idx] * table_value)
|
||||
* l_active_row[idx]);
|
||||
// Check that the first values in the permuted input expression and permuted
|
||||
// fixed expression are the same.
|
||||
|
@ -701,24 +493,232 @@ impl<C: CurveAffine> Evaluator<C> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct ThreadBox<T>(*mut T, usize);
|
||||
#[allow(unsafe_code)]
|
||||
unsafe impl<T> Send for ThreadBox<T> {}
|
||||
#[allow(unsafe_code)]
|
||||
unsafe impl<T> Sync for ThreadBox<T> {}
|
||||
impl<C: CurveAffine> Default for GraphEvaluator<C> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
// Fixed positions to allow easy access
|
||||
constants: vec![
|
||||
C::ScalarExt::zero(),
|
||||
C::ScalarExt::one(),
|
||||
C::ScalarExt::from(2u64),
|
||||
],
|
||||
rotations: Vec::new(),
|
||||
calculations: Vec::new(),
|
||||
num_intermediates: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wraps a mutable slice so it can be passed into a thread without
|
||||
/// hard to fix borrow checks caused by difficult data access patterns.
|
||||
impl<T> ThreadBox<T> {
|
||||
fn wrap(data: &mut [T]) -> Self {
|
||||
Self(data.as_mut_ptr(), data.len())
|
||||
impl<C: CurveAffine> GraphEvaluator<C> {
|
||||
/// Adds a rotation
|
||||
fn add_rotation(&mut self, rotation: &Rotation) -> usize {
|
||||
let position = self.rotations.iter().position(|&c| c == rotation.0);
|
||||
match position {
|
||||
Some(pos) => pos,
|
||||
None => {
|
||||
self.rotations.push(rotation.0);
|
||||
self.rotations.len() - 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn unwrap(&mut self) -> &mut [T] {
|
||||
#[allow(unsafe_code)]
|
||||
unsafe {
|
||||
slice::from_raw_parts_mut(self.0, self.1)
|
||||
/// Adds a constant
|
||||
fn add_constant(&mut self, constant: &C::ScalarExt) -> ValueSource {
|
||||
let position = self.constants.iter().position(|&c| c == *constant);
|
||||
ValueSource::Constant(match position {
|
||||
Some(pos) => pos,
|
||||
None => {
|
||||
self.constants.push(*constant);
|
||||
self.constants.len() - 1
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Adds a calculation.
|
||||
/// Currently does the simplest thing possible: just stores the
|
||||
/// resulting value so the result can be reused when that calculation
|
||||
/// is done multiple times.
|
||||
fn add_calculation(&mut self, calculation: Calculation) -> ValueSource {
|
||||
let existing_calculation = self
|
||||
.calculations
|
||||
.iter()
|
||||
.find(|c| c.calculation == calculation);
|
||||
match existing_calculation {
|
||||
Some(existing_calculation) => ValueSource::Intermediate(existing_calculation.target),
|
||||
None => {
|
||||
let target = self.num_intermediates;
|
||||
self.calculations.push(CalculationInfo {
|
||||
calculation,
|
||||
target,
|
||||
});
|
||||
self.num_intermediates += 1;
|
||||
ValueSource::Intermediate(target)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates an optimized evaluation for the expression
|
||||
fn add_expression(&mut self, expr: &Expression<C::ScalarExt>) -> ValueSource {
|
||||
match expr {
|
||||
Expression::Constant(scalar) => self.add_constant(scalar),
|
||||
Expression::Selector(_selector) => unreachable!(),
|
||||
Expression::Fixed {
|
||||
query_index: _,
|
||||
column_index,
|
||||
rotation,
|
||||
} => {
|
||||
let rot_idx = self.add_rotation(rotation);
|
||||
self.add_calculation(Calculation::Store(ValueSource::Fixed(
|
||||
*column_index,
|
||||
rot_idx,
|
||||
)))
|
||||
}
|
||||
Expression::Advice {
|
||||
query_index: _,
|
||||
column_index,
|
||||
rotation,
|
||||
} => {
|
||||
let rot_idx = self.add_rotation(rotation);
|
||||
self.add_calculation(Calculation::Store(ValueSource::Advice(
|
||||
*column_index,
|
||||
rot_idx,
|
||||
)))
|
||||
}
|
||||
Expression::Instance {
|
||||
query_index: _,
|
||||
column_index,
|
||||
rotation,
|
||||
} => {
|
||||
let rot_idx = self.add_rotation(rotation);
|
||||
self.add_calculation(Calculation::Store(ValueSource::Instance(
|
||||
*column_index,
|
||||
rot_idx,
|
||||
)))
|
||||
}
|
||||
Expression::Negated(a) => match **a {
|
||||
Expression::Constant(scalar) => self.add_constant(&-scalar),
|
||||
_ => {
|
||||
let result_a = self.add_expression(a);
|
||||
match result_a {
|
||||
ValueSource::Constant(0) => result_a,
|
||||
_ => self.add_calculation(Calculation::Negate(result_a)),
|
||||
}
|
||||
}
|
||||
},
|
||||
Expression::Sum(a, b) => {
|
||||
// Undo subtraction stored as a + (-b) in expressions
|
||||
match &**b {
|
||||
Expression::Negated(b_int) => {
|
||||
let result_a = self.add_expression(a);
|
||||
let result_b = self.add_expression(b_int);
|
||||
if result_a == ValueSource::Constant(0) {
|
||||
self.add_calculation(Calculation::Negate(result_b))
|
||||
} else if result_b == ValueSource::Constant(0) {
|
||||
result_a
|
||||
} else {
|
||||
self.add_calculation(Calculation::Sub(result_a, result_b))
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
let result_a = self.add_expression(a);
|
||||
let result_b = self.add_expression(b);
|
||||
if result_a == ValueSource::Constant(0) {
|
||||
result_b
|
||||
} else if result_b == ValueSource::Constant(0) {
|
||||
result_a
|
||||
} else if result_a <= result_b {
|
||||
self.add_calculation(Calculation::Add(result_a, result_b))
|
||||
} else {
|
||||
self.add_calculation(Calculation::Add(result_b, result_a))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Expression::Product(a, b) => {
|
||||
let result_a = self.add_expression(a);
|
||||
let result_b = self.add_expression(b);
|
||||
if result_a == ValueSource::Constant(0) || result_b == ValueSource::Constant(0) {
|
||||
ValueSource::Constant(0)
|
||||
} else if result_a == ValueSource::Constant(1) {
|
||||
result_b
|
||||
} else if result_b == ValueSource::Constant(1) {
|
||||
result_a
|
||||
} else if result_a == ValueSource::Constant(2) {
|
||||
self.add_calculation(Calculation::Double(result_b))
|
||||
} else if result_b == ValueSource::Constant(2) {
|
||||
self.add_calculation(Calculation::Double(result_a))
|
||||
} else if result_a == result_b {
|
||||
self.add_calculation(Calculation::Square(result_a))
|
||||
} else if result_a <= result_b {
|
||||
self.add_calculation(Calculation::Mul(result_a, result_b))
|
||||
} else {
|
||||
self.add_calculation(Calculation::Mul(result_b, result_a))
|
||||
}
|
||||
}
|
||||
Expression::Scaled(a, f) => {
|
||||
if *f == C::ScalarExt::zero() {
|
||||
ValueSource::Constant(0)
|
||||
} else if *f == C::ScalarExt::one() {
|
||||
self.add_expression(a)
|
||||
} else {
|
||||
let cst = self.add_constant(f);
|
||||
let result_a = self.add_expression(a);
|
||||
self.add_calculation(Calculation::Mul(result_a, cst))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new evaluation structure
|
||||
pub fn instance(&self) -> EvaluationData<C> {
|
||||
EvaluationData {
|
||||
intermediates: vec![C::ScalarExt::zero(); self.num_intermediates],
|
||||
rotations: vec![0usize; self.rotations.len()],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn evaluate<B: Basis>(
|
||||
&self,
|
||||
data: &mut EvaluationData<C>,
|
||||
fixed: &[Polynomial<C::ScalarExt, B>],
|
||||
advice: &[Polynomial<C::ScalarExt, B>],
|
||||
instance: &[Polynomial<C::ScalarExt, B>],
|
||||
beta: &C::ScalarExt,
|
||||
gamma: &C::ScalarExt,
|
||||
theta: &C::ScalarExt,
|
||||
y: &C::ScalarExt,
|
||||
previous_value: &C::ScalarExt,
|
||||
idx: usize,
|
||||
rot_scale: i32,
|
||||
isize: i32,
|
||||
) -> C::ScalarExt {
|
||||
// All rotation index values
|
||||
for (rot_idx, rot) in self.rotations.iter().enumerate() {
|
||||
data.rotations[rot_idx] = get_rotation_idx(idx, *rot, rot_scale, isize);
|
||||
}
|
||||
|
||||
// All calculations, with cached intermediate results
|
||||
for calc in self.calculations.iter() {
|
||||
data.intermediates[calc.target] = calc.calculation.evaluate(
|
||||
&data.rotations,
|
||||
&self.constants,
|
||||
&data.intermediates,
|
||||
fixed,
|
||||
advice,
|
||||
instance,
|
||||
beta,
|
||||
gamma,
|
||||
theta,
|
||||
y,
|
||||
previous_value,
|
||||
);
|
||||
}
|
||||
|
||||
// Return the result of the last calculation (if any)
|
||||
if let Some(calc) = self.calculations.last() {
|
||||
data.intermediates[calc.target]
|
||||
} else {
|
||||
C::ScalarExt::zero()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue